app.py

app.py is the most important file in the framework.
It contains two functions: init() and inference().

init()

init() is run once, at server startup.
Use it to instantiate slow-to-load resources, such as:
  • Models
  • Tokenizers
  • Third party service clients (AWS SDK, for example)
Load these objects using the global keyword in Python so that the inference handler can reference them at runtime.
If your model is PyTorch-based, load it into the global model variable to make your project compatible with Banana's fast-boot optimizations. More optimization details here.

inference()

inference() is the handler for running inferences against your preloaded model, and can be run many times.
It receives model_inputs from the SDK as a Python dictionary, and returns model_outputs as a Python dictionary.

Example:

Here is an example app.py configured to run a HuggingFace BERT model:
from transformers import pipeline
import torch

# init() is run once on server startup
# Load your model to GPU as a global variable here using the variable name "model"
def init():
    global model
    device = 0 if torch.cuda.is_available() else -1
    model = pipeline('fill-mask', model='bert-base-uncased', device=device)

# inference() is run for every server call
# Reference your preloaded global model variable here
def inference(model_inputs: dict) -> dict:
    global model

    # Parse out your arguments
    prompt = model_inputs.get('prompt', None)
    if prompt is None:
        return {'message': "No prompt provided"}

    # Run the model
    result = model(prompt)

    # Return the results as a dictionary
    return {'result': result}
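To exercise the handler contract without downloading the real BERT weights, here is a sketch that swaps a stub in for the pipeline; fake_pipeline is an invention for illustration, not part of the framework or of HuggingFace:

```python
# fake_pipeline stands in for the HuggingFace pipeline (illustrative stub only)
def fake_pipeline(prompt):
    return {'input': prompt, 'score': 1.0}

def init():
    global model
    model = fake_pipeline  # a real init() would load the pipeline here

def inference(model_inputs: dict) -> dict:
    global model
    prompt = model_inputs.get('prompt', None)
    if prompt is None:
        return {'message': "No prompt provided"}
    return {'result': model(prompt)}

init()
print(inference({'prompt': 'Hello [MASK]!'}))  # happy path
print(inference({}))                           # missing-prompt path
```

The stub keeps the same init()/inference() shape as the real app.py, so both code paths (prompt present, prompt missing) can be checked locally in milliseconds.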
Key takeaways:
  • We load the model to GPU in init(), so the slow load runs only once
  • We reference that preloaded model on every call to inference()
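Why these takeaways matter can be shown with a small timing sketch; the helper names and the sleep standing in for model loading are illustrative, not part of the framework:

```python
import time

LOAD_TIME = 0.05  # stands in for a slow model load

def load_model():
    time.sleep(LOAD_TIME)
    return "model"  # illustrative stub, not a real model

def serve_load_once(n_calls):
    model = load_model()              # paid once, as in init()
    start = time.perf_counter()
    for _ in range(n_calls):
        _ = (model, "inference")      # cheap per-call work
    return time.perf_counter() - start

def serve_load_per_call(n_calls):
    start = time.perf_counter()
    for _ in range(n_calls):
        model = load_model()          # paid on every call: the anti-pattern
        _ = (model, "inference")
    return time.perf_counter() - start

print(serve_load_once(5) < serve_load_per_call(5))
```

Loading once up front keeps the per-call latency near zero, while reloading per call pays the full load time on every request.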

Exercise:

To fully understand the framework, consider the simplified app.py below:
import time

def init():
    global a
    time.sleep(10)
    a = "hello"
    print(a)

def inference(model_inputs: dict) -> dict:
    global a
    b = a + " world"
    print(b)

    return {"greeting": b}
Imagine we call this server 5 times concurrently and tune the autoscaler to limit max replicas to 1. The single replica runs init() only once, so the logs show a 10-second pause, then a single "hello", then one "hello world" per call:

10s
hello
hello world
hello world
hello world
hello world
hello world

Each of the 5 calls returns {"greeting": "hello world"}.
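The exercise can be checked locally by driving the same app.py with a small harness that calls init() once and then inference() 5 times; the sleep is shortened here so the sketch runs quickly:

```python
import time

def init():
    global a
    time.sleep(0.1)  # shortened from 10s for this sketch
    a = "hello"
    print(a)

def inference(model_inputs: dict) -> dict:
    global a
    b = a + " world"
    print(b)
    return {"greeting": b}

# One replica: init() runs once, then every call reuses the global.
init()
responses = [inference({}) for _ in range(5)]
print(responses)
```

As in the exercise, "hello" is printed once, "hello world" five times, and every call returns the same greeting dictionary.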