I use an object that needs a startup and a teardown step (for example, loading from and saving to a cache) in FastAPI endpoints. I use an asynccontextmanager to manage the object's context, but I also want to process the object later in a background task.
In my environment (fastapi==0.115.5), the context of this object ends before the response is sent. That is typically earlier than the end of the background task, so part of the background task runs outside the context. For example, if the teardown part of the context manager contains a "save to cache" step, the changes made later in the background task are not saved, because they happen after the teardown.
There is a minimal (but still ~150-line) working example on this gist; I'll also paste it here:
from fastapi import FastAPI, Depends, BackgroundTasks, Request
from typing import Annotated, AsyncIterator
from pydantic import BaseModel, Field
from uuid import uuid4
from contextlib import asynccontextmanager
import random
import asyncio

app = FastAPI()


class Chat(BaseModel):
    """
    This is an over-simplified chat history manager that could be used in, e.g., a LangChain-like system.
    There is an additional `total` field because messages are serialized and cached individually, and we don't want to load the whole history when deserializing a chat from the cache/database.
    """

    id: str = Field(default_factory=lambda: uuid4().hex)
    meta: str = "some meta information"
    history: list[str] = []
    total: int = 0
    uncached: int = 0

    def add_message(self, msg: str):
        self.history.append(msg)
        self.total += 1
        self.uncached += 1

    async def save(self, cache: dict):
        # cache the messages that are not cached yet (the last `uncached` entries of history)
        for imsg in range(-self.uncached, 0):
            cache[f"msg:{self.id}:{self.total + imsg}"] = self.history[imsg]
        self.uncached = 0
        # cache everything except the history
        cache[f"sess:{self.id}"] = self.model_dump(exclude={"history"})
        print(f"saved: {self}")

    @classmethod
    async def load(cls, sess_id: str, cache: dict, max_read: int = 30):
        sess_key = f"sess:{sess_id}"
        obj = cls.model_validate(cache.get(sess_key))
        # read back at most the last `max_read` messages
        for imsg in range(max(0, obj.total - max_read), obj.total):
            obj.history.append(cache.get(f"msg:{obj.id}:{imsg}"))
        print(f"loaded: {obj}")
        return obj

    async def chat(self, msg: str, cache: dict):
        """Record the user message and return an async generator that streams the reply."""
        self.add_message(msg)

        async def get_chat():
            resp = []
            for i in range(random.randint(3, 5)):
                # simulate long network IO
                await asyncio.sleep(0.5)
                chunk = f"resp{i}:{random.randbytes(2).hex()};"
                resp.append(chunk)
                yield chunk
            self.add_message("".join(resp))
            # NOTE: to make the message cache work properly, we have to save manually here:
            # await self.save(cache)

        return get_chat()


# use a simple dict to mimic an actual cache, e.g. Redis
cache = {}


async def get_cache():
    return cache


# I didn't figure out how to make Chat a dependable.
# I have read https://fastapi.tiangolo.com/advanced/advanced-dependencies/#parameterized-dependencies but still have no clue;
# the problem is that `sess_id` is passed by the user, not something we can fix in advance as this tutorial shows.
# As an alternative, I used this async context manager.
# In theory, it automatically saves the Chat object after exiting the `async with` block.
@asynccontextmanager
async def get_chat_from_cache(sess_id: str, cache: dict):
    """
    Get the object from the cache (creating one if needed), yield it, then save it back to the cache.
    """
    sess_key = f"sess:{sess_id}"
    if sess_key not in cache:
        obj = Chat()
        obj.id = sess_id
        await obj.save(cache)
    else:
        obj = await Chat.load(sess_id, cache)
    yield obj
    await obj.save(cache)


async def task(sess_id: str, task_id: int, resp_gen: AsyncIterator[str], cache: dict):
    """Consume the streamed reply chunk by chunk (runs as a background task)."""
    async for chunk in resp_gen:
        # do something with chunk, e.g. stream it to the client via a websocket
        await asyncio.sleep(0.5)
        cache[f"chunk:{sess_id}:{task_id}"] = chunk
        task_id += 1


@app.get("/{sess_id}/{task_id}/{prompt}")
async def get_chat(
    req: Request,
    sess_id: str,
    task_id: int,
    prompt: str,
    background_task: BackgroundTasks,
    cache: Annotated[dict, Depends(get_cache)],
):
    print(f"req incoming: {req.url}")
    async with get_chat_from_cache(sess_id=sess_id, cache=cache) as chat:
        resp_gen = await chat.chat(f"prompt:{prompt}", cache=cache)
        background_task.add_task(
            task, sess_id=sess_id, task_id=task_id, resp_gen=resp_gen, cache=cache
        )
    return "success"


@app.get("/{sess_id}")
async def get_sess(
    req: Request, sess_id: str, cache: Annotated[dict, Depends(get_cache)]
):
    print(f"req incoming: {req.url}")
    return (await Chat.load(sess_id=sess_id, cache=cache)).model_dump()
I found a closely related (but not identical) discussion about the lifespan of dependencies. It seems the lifespan of a dependency could be extended into background tasks, though the participants consider this weird behavior. I did think about turning get_chat_from_cache into a yield-based dependency, but I couldn't figure out how to do it correctly. In any case, this approach seems to be discouraged by the FastAPI devs, because the exact timing of dependency teardown is undocumented behavior and might change in future versions.
I know I could probably repeat the teardown manually in the background task, but that seems like a hack. Are there more elegant ways to do this? If there are better design patterns that avoid this issue completely, please let me know.
Background tasks are executed after your endpoint has finished executing (and after the response has been sent). Thus, you cannot keep the context manager open until the background task completes.
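To make the ordering concrete, here is a small FastAPI-free sketch that mimics it with plain asyncio. `session`, `endpoint` and `background` are hypothetical stand-ins for `get_chat_from_cache`, the endpoint and `task`, and the `scheduled` list imitates `BackgroundTasks`, whose callables the server only runs after the handler has returned. The recorded events show the "save" step firing before the background work touches the state:

```python
import asyncio
from contextlib import asynccontextmanager

events: list[str] = []


@asynccontextmanager
async def session():
    # hypothetical stand-in for get_chat_from_cache: "load" on enter, "save" on exit
    events.append("load")
    state = {"messages": []}
    yield state
    events.append(f"save {len(state['messages'])} message(s)")


async def background(state: dict):
    # hypothetical stand-in for the background task consuming resp_gen
    await asyncio.sleep(0)
    state["messages"].append("assistant reply")
    events.append("background appended reply")


async def endpoint(scheduled: list):
    async with session() as state:
        state["messages"].append("user prompt")
        # like background_task.add_task(...): this only *schedules* the work
        scheduled.append((background, state))
    events.append("endpoint returned")


async def main():
    scheduled = []
    await endpoint(scheduled)
    # imitate the server: scheduled tasks run only after the handler has returned
    for func, arg in scheduled:
        await func(arg)


asyncio.run(main())
print(events)
# → ['load', 'save 1 message(s)', 'endpoint returned', 'background appended reply']
```

This is what happens to the `self.add_message("".join(resp))` call at the end of the generator in the question: it runs after `save` has already been called.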
Turning get_chat_from_cache into a dependency will not help either: that worked before FastAPI 0.106.0, but the behavior was changed, and you can no longer use resources from dependencies with yield in background tasks. You need to redesign your app with this in mind.
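One possible redesign, sketched under simplifying assumptions: don't hand the live `Chat` object (or a generator bound to it) to the background task at all. Pass only plain data (`sess_id`, `task_id`, `prompt`) and let the background task enter the load/save context itself, so the teardown runs after the last chunk. The names `chat_session`, `fake_llm` and `run_chat` below are hypothetical, simplified stand-ins for `get_chat_from_cache`, `Chat.chat` and `task`, and the example runs without FastAPI so it is self-contained:

```python
import asyncio
from contextlib import asynccontextmanager
from typing import AsyncIterator

cache: dict = {}  # a plain dict standing in for the real cache


@asynccontextmanager
async def chat_session(sess_id: str, cache: dict) -> AsyncIterator[list[str]]:
    """Hypothetical stand-in for get_chat_from_cache: load on enter, save on exit."""
    history = list(cache.get(f"sess:{sess_id}", []))
    yield history
    # teardown: because the background job owns this context, it runs after streaming ends
    cache[f"sess:{sess_id}"] = history


async def fake_llm(prompt: str) -> AsyncIterator[str]:
    # hypothetical stand-in for the streaming reply produced by Chat.chat
    for i in range(3):
        await asyncio.sleep(0.01)
        yield f"resp{i};"


async def run_chat(sess_id: str, task_id: int, prompt: str, cache: dict) -> None:
    """Background job that owns the whole load -> chat -> save cycle."""
    async with chat_session(sess_id, cache) as history:
        history.append(f"prompt:{prompt}")
        chunks = []
        async for chunk in fake_llm(prompt):
            cache[f"chunk:{sess_id}:{task_id}"] = chunk  # e.g. push to a websocket
            task_id += 1
            chunks.append(chunk)
        history.append("".join(chunks))


async def main() -> None:
    # In the FastAPI app, the endpoint would only schedule the job and return:
    #     background_task.add_task(run_chat, sess_id, task_id, prompt, cache)
    #     return "success"
    await run_chat("s1", 0, "hello", cache)


asyncio.run(main())
print(cache["sess:s1"])  # → ['prompt:hello', 'resp0;resp1;resp2;']
```

One consequence of this design: a concurrent request for the same session (e.g. `GET /{sess_id}`) can see the pre-job state until `run_chat` saves, so you may want a per-session lock if that matters for your use case.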