I would like to build a FastAPI service with one GET endpoint that returns an ML-model inference result. It is pretty easy to implement that, but the catch is that I periodically need to update the model with a newer version (through a request to another server that hosts the models, but that is beside the point), and here I see a problem!
What will happen if one request calls the old model while the old model is being replaced by a newer one? How can I implement this kind of locking mechanism with asyncio?
Here is the code:
import asyncio
import time
from concurrent.futures import ProcessPoolExecutor

from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer

app = FastAPI()
sbertmodel = None


def create_model():
    global sbertmodel
    sbertmodel = SentenceTransformer('multi-qa-MiniLM-L6-cos-v1')


# if you try to run all predicts concurrently, it will result in CPU thrashing.
pool = ProcessPoolExecutor(max_workers=1, initializer=create_model)


def model_predict():
    ts = time.time()
    vector = sbertmodel.encode('How big is London')
    return vector


async def vector_search(vector):
    # simulate I/O call (e.g. Vector Similarity Search using a VectorDB)
    await asyncio.sleep(0.005)


@app.get("/")
async def entrypoint(request: Request):
    loop = asyncio.get_event_loop()
    ts = time.time()
    # worker should be initialized outside endpoint to avoid cold start
    vector = await loop.run_in_executor(pool, model_predict)
    print(f"Model : {int((time.time() - ts) * 1000)}ms")
    ts = time.time()
    await vector_search(vector)
    print(f"io task: {int((time.time() - ts) * 1000)}ms")
    return "ok"
My model update would be implemented through repeated tasks (but that is not important now; a rough sketch follows below): https://fastapi-utils.davidmontague.xyz/user-guide/repeated-tasks/
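For context, this is roughly how that repeated task could look with fastapi-utils. It is just a sketch: fetch_latest_model_name() and current_model_name are hypothetical placeholders for the call to the model server and its bookkeeping, and the actual swap is exactly the part I am unsure about.

from fastapi_utils.tasks import repeat_every

@app.on_event("startup")
@repeat_every(seconds=600)  # poll the model server every 10 minutes
async def refresh_model_task() -> None:
    new_model_name = await fetch_latest_model_name()  # hypothetical helper
    if new_model_name != current_model_name:          # hypothetical bookkeeping
        ...  # replace the model here - this swap is the part in question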
This is the idea behind the model serving: https://luis-sena.medium.com/how-to-optimize-fastapi-for-ml-model-serving-6f75fb9e040d
EDIT: what is important is to run multiple requests concurrently, and while the model is updating, to acquire a lock so that requests wouldn't fail; they should just wait a little bit longer, because it is a small model.
Thanks for your snippet. With it visible, it is possible to write a proposal for what you need there - as it turns out, you need to update the model in a subprocess, and there is nothing to worry about in the main-process async part of the code. Signaling the worker processes for the updates, though, needs some attention.
Since you are using ProcessPool workers, you need a way to expose variables from the root process that the worker processes can "see" - Python has this in the form of multiprocessing.Manager objects.
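Just to illustrate the mechanism in isolation, here is a minimal, self-contained sketch (the names are made up for the demo) showing that an attribute set on a Manager Namespace in the root process is visible inside a pool worker that received the namespace via initargs:

from concurrent.futures import ProcessPoolExecutor
from multiprocessing import Manager

shared = None  # will hold the namespace proxy inside each worker process


def init_worker(ns):
    global shared
    shared = ns  # store the proxy object in _this_ worker process


def read_value():
    # reads the value that the root process assigned to the namespace
    return shared.greeting


if __name__ == "__main__":
    manager = Manager()
    ns = manager.Namespace()
    ns.greeting = "hello from the root process"
    with manager, ProcessPoolExecutor(max_workers=1, initializer=init_worker, initargs=(ns,)) as pool:
        print(pool.submit(read_value).result())  # -> "hello from the root process"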
Below I take your code and add the parts needed for your requirement of "not immediate, but not conflicting" updating of the model in use. As it turns out, once we have variables that can be seen in the worker, all that is needed is a check in the model-runner method itself to see if the model needs to be updated.
I didn't run this snippet - so there might be some typo in variable names or even one or other missing parenthesis - use it as a model, not "copy + paste" (but I did test the "moving parts" of Manager.Namespace() objects and passing them as initargs to a ProcessPoolExecutor).
import asyncio
import time
from concurrent.futures import ProcessPoolExecutor
from contextlib import asynccontextmanager
from multiprocessing import Manager

from fastapi import FastAPI, Request
from sentence_transformers import SentenceTransformer

sbertmodel = None
local_model_iteration = -1
shared_namespace = None


# The pool and the other multiprocessing objects can't simply be
# created at the top level of the module body, or they'd be
# re-created in each subprocess!
# check https://fastapi.tiangolo.com/advanced/events/#lifespan
@asynccontextmanager
async def lifespan(app: FastAPI):
    global pool, root_namespace
    manager = Manager()
    root_namespace = manager.Namespace()
    # Values assigned to the "namespace" object are
    # visible in the subprocesses created by the pool
    root_namespace.model_iteration = 0
    root_namespace.model_parameters = "multi-qa-MiniLM-L6-cos-v1"
    # (as long as we send the namespace object to each subprocess
    # and store it there)
    pool = ProcessPoolExecutor(max_workers=1, initializer=initialize_subprocess, initargs=(root_namespace,))
    with pool, manager:
        # pass control to fastapi: all the app is executed
        yield
    # end of "with" block:
    # both the pool and the manager are shut down when the fastapi server exits!


app = FastAPI(lifespan=lifespan)


# if you try to run all predicts concurrently, it will result in CPU thrashing.
def initialize_subprocess(shared_namespace_arg):
    global shared_namespace
    # Store the shared namespace in _this_ process:
    shared_namespace = shared_namespace_arg
    update_model()


def update_model():
    "called on worker subprocess start, and at any time the model is outdated"
    global local_model_iteration, sbertmodel
    local_model_iteration = shared_namespace.model_iteration
    # retrieve the parameters posted by the root process:
    sbertmodel = SentenceTransformer(shared_namespace.model_parameters)


def model_predict():
    ts = time.time()
    # verify if the model was updated from the root process
    if shared_namespace.model_iteration > local_model_iteration:
        # if so, just update the model
        update_model()
    # the model is synchronized, just do our job:
    vector = sbertmodel.encode('How big is London')
    return vector


async def vector_search(vector):
    # simulate I/O call (e.g. Vector Similarity Search using a VectorDB)
    await asyncio.sleep(0.005)


@app.get("/")
async def entrypoint(request: Request):
    loop = asyncio.get_event_loop()
    ts = time.time()
    # worker should be initialized outside endpoint to avoid cold start
    vector = await loop.run_in_executor(pool, model_predict)
    print(f"Model : {int((time.time() - ts) * 1000)}ms")
    ts = time.time()
    await vector_search(vector)
    print(f"io task: {int((time.time() - ts) * 1000)}ms")
    return "ok"


@app.get("/update_model")
async def update_model_endpoint(request: Request):
    # extract from the request the needed parameters for the new model
    ...
    new_model_parameters = ...
    # update the model parameters and the model iteration so they are visible
    # in the worker(s)
    root_namespace.model_parameters = new_model_parameters
    # This increment taking place _after_ the "model_parameters" are set
    # is all that is needed to keep things running in order here:
    root_namespace.model_iteration += 1
    return {}  # whatever response is needed by the endpoint
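As a side note, if you prefer to drive the update from a periodic task (the repeated-tasks idea in your question) instead of an HTTP endpoint, the same two assignments are all that is needed. A rough sketch, assuming a hypothetical fetch_new_model_parameters() coroutine that asks your model server for the current model name:

async def refresh_model_periodically(interval: float = 600) -> None:
    while True:
        await asyncio.sleep(interval)
        new_parameters = await fetch_new_model_parameters()  # hypothetical helper
        if new_parameters != root_namespace.model_parameters:
            # same ordering as in the endpoint: set the parameters first,
            # then bump the iteration that the workers check on each predict
            root_namespace.model_parameters = new_parameters
            root_namespace.model_iteration += 1

Inside lifespan you would start it with task = asyncio.create_task(refresh_model_periodically()) right before the yield and cancel it afterwards, so it lives and dies with the app.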