app.py
`app.py` is the most important file in the framework. It contains two functions: `init()` and `inference()`.
`init()` is run once, at server startup. Use it to instantiate slow-to-load resources, such as:
- Models
- Tokenizers
- Third-party service clients (the AWS SDK, for example)

Load these objects using the `global` keyword in Python so that the inference handler can reference them at runtime. If your model is PyTorch based, load it into the global `model` variable to make your project compatible with Banana's fast-boot optimizations. More optimization details here.
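For instance, here is a minimal sketch of that pattern. It assumes a hypothetical checkpoint file `weights.pt` saved as a full model object (the filename and loading approach are illustrative, not prescribed by the framework):

```python
import torch

def init():
    # "global" makes the loaded object visible to inference() at request time
    global model

    device = "cuda" if torch.cuda.is_available() else "cpu"
    # Hypothetical checkpoint path; the load happens once, at server startup
    model = torch.load("weights.pt", map_location=device)
    model.eval()
```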
`inference()` is the handler for running inferences against your preloaded model, and can be run many times. It receives `model_inputs` from the SDK as a Python dictionary, and returns `model_outputs` as a Python dictionary. Here is an example app.py configured to run a HuggingFace BERT model:
```python
from transformers import pipeline
import torch

# Init is run once on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
    global model

    device = 0 if torch.cuda.is_available() else -1
    model = pipeline('fill-mask', model='bert-base-uncased', device=device)

# Inference is run for every server call
# Reference your preloaded global model variable here.
def inference(model_inputs: dict) -> dict:
    global model

    # Parse out your arguments
    prompt = model_inputs.get('prompt', None)
    if prompt is None:
        return {'message': "No prompt provided"}

    # Run the model
    result = model(prompt)

    # Return the results as a dictionary
    return {'outputs': result}
```
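To sanity-check the handler locally, you can call both functions directly. The snippet below is only a rough sketch of that idea (in production the SDK delivers `model_inputs` over the network; this just mimics the dictionaries going in and out):

```python
if __name__ == "__main__":
    # Mimic server startup, then a single request
    init()
    model_outputs = inference({"prompt": "The capital of France is [MASK]."})
    print(model_outputs)
```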
Key takeaways:
- We load the model object to GPU in `init()`, so it only runs once.
- We reference that preloaded model for every inference.
To fully understand the framework, consider the simplified app.py below:
```python
import time

def init():
    global a

    # Simulate a slow-to-load resource
    time.sleep(10)
    a = "hello"
    print(a)

def inference(model_inputs: dict) -> dict:
    global a

    b = a + " world"
    print(b)
    return {"greeting": b}
```
Imagine we call this server 5 times concurrently and tune the autoscaler to limit max replicas to 1.
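As a rough illustration (not part of the framework), the sketch below runs those same two functions in-process: the 10-second `init()` cost is paid once, and the five `inference()` calls return almost immediately afterwards. The same intuition carries over to a single replica handling five queued requests.

```python
import time

init()                      # ~10 seconds, paid once at replica startup
start = time.time()
for _ in range(5):
    inference({})           # each call reuses the preloaded global "a"
print(f"5 inference calls took ~{time.time() - start:.2f}s")
```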