Tags: python, fastapi, starlette, python-contextvars, fastapi-middleware

Keeping context variable values across FastAPI/Starlette middlewares regardless of middleware order


I am developing a FastAPI app, and my goal is to record some information in the request scope and reuse it later in log records.

My idea was to use context variables (contextvars) to store the "request context", a middleware to process the request and set the context variable, and finally a logging Filter to attach the context variable values to each LogRecord.
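Context variables suit per-request data because each asyncio task runs in its own copy of the context, so concurrent requests cannot see each other's values. A stdlib-only sketch of that isolation, where `request_id` and `handle` are hypothetical names standing in for the app's variable and its request handling:

```python
import asyncio
from contextvars import ContextVar
from typing import List, Optional

# Hypothetical request-scoped variable; None means "no request id yet"
request_id: ContextVar = ContextVar("request_id", default=None)


async def handle(calculated_id: str) -> Optional[str]:
    # Each task created by gather() runs in its own copy of the context,
    # so this set() is invisible to the other, concurrently running task
    request_id.set(calculated_id)
    await asyncio.sleep(0)  # yield so the two "requests" interleave
    return request_id.get()


async def main() -> List[Optional[str]]:
    return list(await asyncio.gather(handle("req-A"), handle("req-B")))


results = asyncio.run(main())
print(results)  # ['req-A', 'req-B']
print(request_id.get())  # None: the tasks' changes never reach this context
```

The same copy-on-task-creation rule that isolates sibling requests is what matters for the middleware ordering problem below.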

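This is my app skeleton:

import logging

from fastapi import Depends, FastAPI

from vars import req_id  # module-level ContextVar, also read by the logging filter

logger = logging.getLogger(__name__)
app = FastAPI()
# SetterMiddlware and FooMiddleware are defined below; the one added last runs outermost
app.add_middleware(SetterMiddlware)
app.add_middleware(FooMiddleware)

@app.get("/")
def read_root(setter=Depends(set_request_id)):  # set_request_id: dependency not shown
    print("Adding req_id to body", req_id.get())  # This is 1234567890
    logging.info("hello")
    return {"Req_id": str(req_id.get())}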

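These are my middlewares:

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request

from vars import req_id


class SetterMiddlware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        calculated_id = "1234567890"
        req_id.set(calculated_id)  # store the id in the current context
        request.state.req_id = calculated_id
        response = await call_next(request)  # forward to the rest of the stack
        return response

class FooMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Does nothing except forward the request
        response = await call_next(request)
        return response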
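Because the behaviour depends on registration order, it helps to see who wraps whom. Below is a simplified toy model of how `add_middleware()` builds the stack (not Starlette's actual code, and it ignores the built-in error-handling middleware Starlette adds at both ends): each call ends up wrapping everything registered before it, so the last middleware added is entered first. The `Recorder` base class and `build_stack` helper are hypothetical.

```python
from typing import Callable, List, Type

Handler = Callable[[], None]
calls: List[str] = []


class Recorder:
    """Hypothetical middleware stand-in that records when it is entered."""

    name = "Recorder"

    def __init__(self, app: Handler) -> None:
        self.app = app

    def __call__(self) -> None:
        calls.append(self.name)  # entered before the wrapped app
        self.app()


class SetterMiddlware(Recorder):
    name = "SetterMiddlware"


class FooMiddleware(Recorder):
    name = "FooMiddleware"


def endpoint() -> None:
    calls.append("endpoint")


def build_stack(app: Handler, added_in_order: List[Type[Recorder]]) -> Handler:
    # Each registration wraps everything registered before it,
    # so the last class in the list becomes the outermost layer
    for middleware_cls in added_in_order:
        app = middleware_cls(app)
    return app


stack = build_stack(endpoint, [SetterMiddlware, FooMiddleware])
stack()
print(calls)  # ['FooMiddleware', 'SetterMiddlware', 'endpoint']
```

With the skeleton's registration order, FooMiddleware therefore wraps SetterMiddlware.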

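And the logging filter:

from logging import Filter, LogRecord

from vars import req_id


class CustomFilter(Filter):
    """Logging filter that attaches the current request id to log records"""

    def filter(self, record: LogRecord) -> bool:
        # Reads the ContextVar in whatever context the log call happens in
        record.req_id = req_id.get()
        return True  # never drop records, only enrich them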
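The `vars` module is not shown. A minimal sketch, assuming it only defines a module-level `ContextVar` with a default of `None` (consistent with the `[None]` in the logs below), and exercising the filter on a throwaway logger that writes to an in-memory buffer (`demo_logger` and `buffer` are illustrative names):

```python
import io
import logging
from contextvars import ContextVar
from logging import Filter, LogRecord

# Assumed contents of the `vars` module: a ContextVar defaulting to None
req_id: ContextVar = ContextVar("req_id", default=None)


class CustomFilter(Filter):
    """Attach the current request id to every record passing through."""

    def filter(self, record: LogRecord) -> bool:
        record.req_id = req_id.get()
        return True


buffer = io.StringIO()
handler = logging.StreamHandler(buffer)
handler.addFilter(CustomFilter())
handler.setFormatter(logging.Formatter("[%(req_id)s]| %(message)s"))

demo_logger = logging.getLogger("req_id_demo")
demo_logger.addHandler(handler)
demo_logger.setLevel(logging.INFO)
demo_logger.propagate = False  # keep the demo output out of other handlers

demo_logger.info("before")
token = req_id.set("1234567890")  # what the middleware does per request
demo_logger.info("inside request")
req_id.reset(token)  # restore the previous value

print(buffer.getvalue())
# [None]| before
# [1234567890]| inside request
```

The filter simply reads whatever value is visible in the context where the logging call runs, which is why the context that emits a log line matters.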

Finally, here is part of my logging configuration:

...
    "formatters": {
        "default": {
            "format": "%(levelname)-9s %(asctime)s [%(req_id)s]| %(message)s",
            "datefmt": "%Y-%m-%d,%H:%M:%S",
        },
    },
    ...
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stderr",
            "filters": [
                "custom_filter",
            ],
            "level": logging.NOTSET,
        },
    ...
    "loggers": {
        "": {
            "handlers": ["console"],
            "level": logging.DEBUG,
        },
        "uvicorn": {"handlers": ["console"], "propagate": False},
    },
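For experimenting outside the app, a self-contained version of this configuration might look like the sketch below. The `filters` section is an assumption, since the question elides it; `dictConfig` accepts either a callable or a dotted import path under the `"()"` key. The sketch also sets `disable_existing_loggers` to `False`, because its default of `True` disables already-existing loggers that the config does not name, such as a module-level `logging.getLogger(__name__)`.

```python
import logging
import logging.config
from contextvars import ContextVar

# Hypothetical stand-ins for `vars.req_id` and the CustomFilter above
req_id: ContextVar = ContextVar("req_id", default=None)


class CustomFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        record.req_id = req_id.get()
        return True


LOGGING = {
    "version": 1,
    "disable_existing_loggers": False,
    "filters": {
        # "()" names the factory; here the class object itself is passed
        "custom_filter": {"()": CustomFilter},
    },
    "formatters": {
        "default": {
            "format": "%(levelname)-9s %(asctime)s [%(req_id)s]| %(message)s",
            "datefmt": "%Y-%m-%d,%H:%M:%S",
        },
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "formatter": "default",
            "stream": "ext://sys.stderr",
            "filters": ["custom_filter"],
            "level": logging.NOTSET,
        },
    },
    "loggers": {
        "": {"handlers": ["console"], "level": logging.DEBUG},
        "uvicorn": {"handlers": ["console"], "propagate": False},
    },
}

logging.config.dictConfig(LOGGING)
logging.info("outside any request")  # rendered with [None]
```

Attaching the filter to the handler (rather than to a logger) means every record that reaches `console`, including uvicorn's, gets a `req_id` attribute, so the format string never fails on a missing field.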

When SetterMiddlware is the last middleware added to the app (i.e. with the FooMiddleware line commented out), my app logs as expected:

Adding req_id to body 1234567890
INFO      2025-04-14,15:02:28 [1234567890]| hello
INFO      2025-04-14,15:02:28 [1234567890]| 127.0.0.1:52912 - "GET / HTTP/1.1" 200

But if I add another middleware after SetterMiddlware, the uvicorn access logger no longer sees the value set in the req_id context variable:

Adding req_id to body 1234567890
INFO      2025-04-14,15:03:56 [1234567890]| hello
INFO      2025-04-14,15:03:56 [None]| 127.0.0.1:52919 - "GET / HTTP/1.1" 200

I also tried the starlette-context package (https://starlette-context.readthedocs.io/en/latest/), without success; it seems to suffer from the same problem.

I would like to know why this happens and how to fix it without being forced to add SetterMiddlware last.


Solution

  • I'm currently dealing with a similar setup and spent some time digging.

    I'm not sure this fully answers the "why" in your question, but I found that if I write the middleware classes without inheriting from BaseHTTPMiddleware (i.e. as pure ASGI middleware), the context variables are propagated correctly to the uvicorn access logger.

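    This is most likely the known Starlette BaseHTTPMiddleware limitation: changes to contextvars made further down the stack do not propagate "upwards". Its call_next runs the rest of the stack in a separate anyio task, which works on a copy of the context, while the response (and with it uvicorn's access log line) is sent from the middleware's own task. So when FooMiddleware is added last and wraps SetterMiddlware, the req_id.set() happens in a context that FooMiddleware, and therefore the access logger, never sees. IIRC some anyio issues about contextvars have been raised too.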

    So the solution would be along the lines of:

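    from starlette.requests import Request
    from starlette.types import ASGIApp, Receive, Scope, Send

    from vars import req_id


    class SetterMiddlware:
        def __init__(self, app: ASGIApp) -> None:
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            # Pass lifespan (and websocket) scopes through untouched;
            # Request() would reject a "lifespan" scope anyway
            if scope["type"] != "http":
                await self.app(scope, receive, send)
                return

            calculated_id = "1234567890"
            # No extra task is spawned, so this value stays visible for the rest
            # of the request, including uvicorn's access log written from `send`
            req_id.set(calculated_id)
            request = Request(scope, receive)
            request.state.req_id = calculated_id
            await self.app(scope, receive, send)


    class FooMiddleware:
        def __init__(self, app: ASGIApp) -> None:
            self.app = app

        async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
            # Await the inner app directly; ASGI apps return None
            await self.app(scope, receive, send)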
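To see the mechanism without FastAPI, here is a stdlib-only asyncio model (a simplification, not Starlette's code). `task_based` approximates BaseHTTPMiddleware: the real `call_next` starts the downstream app in an anyio task group and streams the buffered response out from the outer task, whereas this model just awaits a child task. `passthrough` approximates a pure ASGI middleware, and `server_send` stands in for uvicorn's `send`, where the access log line is written. All names are hypothetical.

```python
import asyncio
from contextvars import ContextVar
from typing import Awaitable, Callable, List, Optional

# Hypothetical stand-ins for the question's ContextVar and uvicorn's access log
req_id: ContextVar = ContextVar("req_id", default=None)
seen_by_server: List[Optional[str]] = []

Send = Callable[[], Awaitable[None]]
App = Callable[[Send], Awaitable[None]]


async def server_send() -> None:
    # The "access log" reads the ContextVar in the context that sends
    seen_by_server.append(req_id.get())


async def endpoint(send: Send) -> None:
    await send()


def setter(app: App) -> App:
    async def wrapped(send: Send) -> None:
        req_id.set("1234567890")  # changes the context of the current task
        await app(send)
    return wrapped


def passthrough(app: App) -> App:
    # Pure ASGI style: the inner app is awaited directly in the same task
    async def wrapped(send: Send) -> None:
        await app(send)
    return wrapped


def task_based(app: App) -> App:
    # BaseHTTPMiddleware style (simplified): the downstream app runs in a
    # child task, which works on a *copy* of the current context ...
    async def wrapped(send: Send) -> None:
        async def buffer_send() -> None:
            pass  # the real middleware queues response messages here

        await asyncio.create_task(app(buffer_send))
        # ... and the response is then sent from this (outer) task's context
        await send()
    return wrapped


# Setter outermost: the value is visible when the response is sent
asyncio.run(setter(task_based(endpoint))(server_send))
# Task-based middleware outermost: the child task's set() is lost
asyncio.run(task_based(setter(endpoint))(server_send))
# Pure ASGI middleware outermost: a single task, so the value survives
asyncio.run(passthrough(setter(endpoint))(server_send))

print(seen_by_server)  # ['1234567890', None, '1234567890']
```

The three runs mirror the three situations discussed: SetterMiddlware added last, FooMiddleware (as BaseHTTPMiddleware) added last, and the pure ASGI rewrite.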