python mypy python-typing

Type-hinting a decorator that injects a value but also supports passing it explicitly


I am trying to implement a decorator that injects a `DBConnection`. The problem is that I would like to support both passing the argument explicitly and relying on the decorator to inject it. I have tried doing this with `@overload` but failed. Here is reproducible code:

import inspect
from functools import wraps
from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from typing_extensions import reveal_type


class DBConnection:
    ...


T = TypeVar("T")
P = ParamSpec("P")


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], Awaitable[T]]
) -> Callable[P, Awaitable[T]]:
    @wraps(f)
    async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        signature = inspect.signature(f).parameters
        passed_args = dict(zip(signature, args))
        if "db_connection" in kwargs or "db_connection" in passed_args:
            return await f(*args, **kwargs)

        return await f(DBConnection(), *args, **kwargs)

    return inner


@inject_db_connection
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    assert db_connection
    return {"user_id": user_id}


async def main() -> None:
    # ↓ No issue, great!
    user1 = await get_user(user_id=1)

    # ↓ Understandably fails with:
    # `Unexpected keyword argument "db_connection" for "get_user"  [call-arg]`
    # but I would like to support passing `db_connection` explicitly as well.
    db_connection = DBConnection()
    user2 = await get_user(db_connection=db_connection, user_id=1)

    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user1)
    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user2)
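A small runtime check of the detection logic in the question's `inner` may help show why it is fragile. The helper `looks_passed` below is hypothetical; it only copies that check so it can be called directly. Because `zip` pairs positional arguments with parameter names starting at `db_connection`, a positional call such as `get_user(1)` appears to supply a connection even though it does not.

```python
import inspect


class DBConnection:
    ...


async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    return {"user_id": user_id}


def looks_passed(f, args: tuple, kwargs: dict) -> bool:
    # Hypothetical helper: the same check as in the question's `inner`,
    # pairing positional args with parameter names in declaration order.
    passed_args = dict(zip(inspect.signature(f).parameters, args))
    return "db_connection" in kwargs or "db_connection" in passed_args


# Keyword-only call: correctly detected as "connection not passed".
print(looks_passed(get_user, (), {"user_id": 1}))
# Positional call without a connection: 1 is paired with "db_connection",
# so the check wrongly concludes that a connection was supplied.
print(looks_passed(get_user, (1,), {}))
```

In the second case the decorator would forward `1` as `db_connection` and `user_id` would be missing.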


Solution

  • Using the trick described in the answers to the question "python typing signature (typing.Callable) for function with kwargs":

    I managed to write a version that passes type checking, although it requires a single `cast`; I don't see how that could be avoided here. I also took the liberty of changing the runtime logic inside `inner` to make it more robust.

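    import inspect
    from functools import wraps
    from typing import Callable, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload
    
    
    class DBConnection: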
        ...
    
    
    T = TypeVar("T", covariant=True)
    P = ParamSpec("P")
    
    
    class CallMaybeDB(Protocol[P, T]):
        @overload
        def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...
    
        @overload
        def __call__(self, db_connection: DBConnection, *args: P.args, **kwargs: P.kwargs) -> T: ...
    
        # Possibly not required: my IDE complains if this implementation is missing, but mypy doesn't.
        def __call__(self, *args, **kwargs) -> T: ...
    
    def inject_db_connection(
            f: Callable[Concatenate[DBConnection, P], T]
    ) -> CallMaybeDB[P, T]:
        signature = inspect.signature(f)
        if "db_connection" not in signature.parameters:
            raise TypeError("Function should expect db_connection parameter")
    
        @wraps(f)
        def inner(*args, **kwargs) -> T:
            bound = signature.bind_partial(*args, **kwargs)
            if "db_connection" not in bound.arguments:
                bound.arguments["db_connection"] = DBConnection()
    
            return f(*bound.args, **bound.kwargs)
    
        return cast(CallMaybeDB[P, T], inner)
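To check the runtime behaviour of the decorator above with `async` functions, here is a minimal sketch that keeps its `inner` logic but replaces the `CallMaybeDB` typing with plain `Callable[..., Any]` so it runs standalone. The extra `has_db` key returned by `get_user` is hypothetical and only makes the injection visible. The last call exercises one caveat: `bind_partial` binds positional arguments from the first parameter, so a positional call without a connection still puts its first argument into `db_connection`.

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


class DBConnection:
    ...


def inject_db_connection(f: Callable[..., Any]) -> Callable[..., Any]:
    # Runtime logic copied from the answer; type annotations simplified.
    signature = inspect.signature(f)
    if "db_connection" not in signature.parameters:
        raise TypeError("Function should expect db_connection parameter")

    @wraps(f)
    def inner(*args: Any, **kwargs: Any) -> Any:
        bound = signature.bind_partial(*args, **kwargs)
        if "db_connection" not in bound.arguments:
            bound.arguments["db_connection"] = DBConnection()
        # bound.args / bound.kwargs follow the signature's parameter order.
        # For an async f this returns the coroutine, which the caller awaits.
        return f(*bound.args, **bound.kwargs)

    return inner


@inject_db_connection
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    return {"user_id": user_id, "has_db": isinstance(db_connection, DBConnection)}


async def main() -> list[dict]:
    explicit = DBConnection()
    return [
        await get_user(user_id=1),  # connection injected by the decorator
        await get_user(db_connection=explicit, user_id=2),  # passed by keyword
        await get_user(explicit, 3),  # passed positionally
    ]


results = asyncio.run(main())
print(results)

# Caveat: 4 is bound to db_connection, user_id ends up missing,
# and calling f raises TypeError before any coroutine is created.
try:
    get_user(4)
except TypeError as exc:
    print("TypeError:", exc)
```

Keyword calls and calls that pass the connection first work as intended; positional calls that omit the connection do not.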