Tags: python, fastapi, middleware, asgi

How to unit test a pure ASGI middleware in python


I have an ASGI middleware that adds fields to the body of POST requests before they reach the route in my FastAPI app.

import json

from starlette.types import ASGIApp, Message, Scope, Receive, Send

class MyMiddleware:
    """Pure ASGI middleware that adds ``some_field`` to JSON-object request bodies.

    This middleware implements a raw ASGI middleware instead of a
    ``starlette.middleware.base.BaseHTTPMiddleware`` because the
    BaseHTTPMiddleware does not allow us to modify the request body.
    For documentation see https://www.starlette.io/middleware/#pure-asgi-middleware

    Non-HTTP scopes (e.g. ``lifespan``, ``websocket``) are passed through
    untouched. For HTTP scopes the request body is buffered in full, because
    it may arrive split across several ``http.request`` messages. If the
    body decodes as a UTF-8 JSON object, ``"some_field": "foobar"`` is
    added to it. Bodies that are empty, are not UTF-8 JSON (form data, file
    uploads, ...), or are JSON but not an object (e.g. a list) are forwarded
    unchanged.

    NOTE(review): the ``content-length`` header in ``scope`` is not updated
    after the body is rewritten.
    """

    def __init__(self, app: ASGIApp):
        self.app = app

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] != "http":
            await self.app(scope, receive, send)
            return

        async def modify_message():
            message: dict = await receive()
            if message.get("type", "") != "http.request":
                return message
            # The body can be split across several messages (more_body=True);
            # a single chunk is generally not valid JSON, so collect them all.
            chunks = [message.get("body", b"")]
            while message.get("more_body", False):
                message = await receive()
                if message.get("type", "") != "http.request":
                    # e.g. http.disconnect in the middle of the upload:
                    # hand it to the app as is.
                    return message
                chunks.append(message.get("body", b""))
            body = b"".join(chunks)
            if body:
                body = self._add_fields(body)
            return {"type": "http.request", "body": body, "more_body": False}

        await self.app(scope, modify_message, send)

    @staticmethod
    def _add_fields(body: bytes) -> bytes:
        """Return ``body`` with ``some_field`` added if it is a JSON object, else unchanged."""
        try:
            payload = json.loads(body.decode("utf-8"))
        except ValueError:
            # UnicodeDecodeError and json.JSONDecodeError both subclass
            # ValueError: this is not a UTF-8 JSON body, so leave it alone
            # rather than failing the request.
            return body
        if not isinstance(payload, dict):
            return body
        payload["some_field"] = "foobar"
        return json.dumps(payload).encode("utf-8")

Is there an example of how to unit test an ASGI middleware? I would like to test the `__call__` method directly, which is difficult because it does not return anything. Do I need to use a test client (e.g. `TestClient` from FastAPI) with a dummy endpoint that returns the request as its response, and check that way whether the middleware worked, or is there a more "direct" way?


Solution

  • I faced a similar problem recently, so I want to share my solution for FastAPI and pytest.

    I had to implement per-request logging for a FastAPI app using middleware.

    I checked Starlette's test suite, as Marcelo Trylesinski suggested, and adapted the code to FastAPI. Thank you for the recommendation, Marcelo!

    Here is my middleware that logs information from every request and response.

    # middlewares.py
    import logging
    
    from starlette.types import ASGIApp, Scope, Receive, Send
    
    
    logger = logging.getLogger("app")


    class LogRequestsMiddleware:
        """Pure ASGI middleware that logs one access-log line per HTTP response.

        The line has the form ``<client> - "<METHOD> <path> <scheme>/<version>" <status>``
        and is emitted on the ``"app"`` logger when the wrapped app sends
        ``http.response.start``. All messages are forwarded to ``send`` unchanged.
        """

        def __init__(self, app: ASGIApp) -> None:
            self.app = app

        async def __call__(
            self, scope: Scope, receive: Receive, send: Send
        ) -> None:
            async def send_with_logs(message) -> None:
                """Log every request info and response status code."""
                if message["type"] == "http.response.start":
                    # Request info is stored in the scope, the status code in
                    # the message. Per the ASGI spec, "client" may be absent or
                    # None (e.g. unix sockets) and "scheme" defaults to "http".
                    client = scope.get("client")
                    client_addr = f"{client[0]}:{client[1]}" if client else "-"
                    # Lazy %-args: formatting only happens if INFO is enabled.
                    logger.info(
                        '%s - "%s %s %s/%s" %s',
                        client_addr,
                        scope["method"],
                        scope["path"],
                        scope.get("scheme", "http"),
                        scope["http_version"],
                        message["status"],
                    )
                await send(message)

            await self.app(scope, receive, send_with_logs)
    

    To test the middleware, I created a `test_client_factory` fixture that returns the `TestClient` class, so each test can build a client for its own app:

    # conftest.py
    import pytest
    
    from fastapi.testclient import TestClient
    
    
    @pytest.fixture
    def test_client_factory() -> "type[TestClient]":
        """Provide the ``TestClient`` class (not an instance).

        Tests call ``test_client_factory(app)`` to build a client around an
        app they configured themselves. The annotation is quoted so it is not
        evaluated at runtime on Python versions without ``type[...]`` support.
        """
        return TestClient
    

    In the test, I mocked the `logger.info()` call inside the middleware and asserted that it was called.

    # test_middlewares.py
    from unittest import mock
    from fastapi.testclient import TestClient
    from fastapi import FastAPI
    from .middlewares import LogRequestsMiddleware
    
    # mock logger call within the pure middleware
    # NOTE: "path.to.middlewares" is a placeholder - replace it with the real
    # dotted module path of middlewares.py; mock.patch must target the module
    # in which the middleware looks up ``logger``.
    @mock.patch("path.to.middlewares.logger.info")
    def test_log_requests_middleware(
        mock_logger: mock.MagicMock, test_client_factory: "type[TestClient]"
    ) -> None:
        """LogRequestsMiddleware logs once when a single request is handled."""
        # create a fresh app instance to isolate the middleware under test
        app = FastAPI()
        app.add_middleware(LogRequestsMiddleware)

        # create an endpoint to push a request through the middleware
        @app.get("/")
        def homepage():
            return {"hello": "world"}

        # create a client for the app using the fixture; the context manager
        # closes the client (and runs app startup/shutdown) on exit
        with test_client_factory(app) as client:
            # call an endpoint
            response = client.get("/")

        # sanity check
        assert response.status_code == 200
        # check if the logger was called
        mock_logger.assert_called_once()