I am trying to implement a decorator that injects a `DBConnection`. The problem I am facing is that I would like to support both: passing the argument explicitly and relying on the decorator to inject it. I have tried doing this with `@overload`, but failed. Here is reproducible code:
```python
import inspect
from functools import wraps
from typing import Awaitable, Callable, Concatenate, ParamSpec, TypeVar

from typing_extensions import reveal_type


class DBConnection:
    ...


T = TypeVar("T")
P = ParamSpec("P")


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], Awaitable[T]]
) -> Callable[P, Awaitable[T]]:
    @wraps(f)
    async def inner(*args: P.args, **kwargs: P.kwargs) -> T:
        signature = inspect.signature(f).parameters
        passed_args = dict(zip(signature, args))
        if "db_connection" in kwargs or "db_connection" in passed_args:
            return await f(*args, **kwargs)
        return await f(DBConnection(), *args, **kwargs)

    return inner


@inject_db_connection
async def get_user(db_connection: DBConnection, user_id: int) -> dict:
    assert db_connection
    return {"user_id": user_id}


async def main() -> None:
    # ↓ No issue, great!
    user1 = await get_user(user_id=1)

    # ↓ Understandably fails with:
    # `Unexpected keyword argument "db_connection" for "get_user" [call-arg]`
    # but I would like to support passing `db_connection` explicitly as well.
    db_connection = DBConnection()
    user2 = await get_user(db_connection=db_connection, user_id=1)

    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user1)
    # ↓ Revealed type is "builtins.dict[Any, Any]", perfect.
    reveal_type(user2)
```
Using the trick described in the answers to the question "python typing signature (typing.Callable) for function with kwargs", I managed to write a version that passes type checking. It does require a single `cast`; I don't know how that could be avoided here. I also took the liberty of changing the actual functional part inside `inner` to be more robust.
```python
import inspect
from functools import wraps
from typing import Callable, Concatenate, ParamSpec, Protocol, TypeVar, cast, overload


class DBConnection:
    ...


T = TypeVar("T", covariant=True)
P = ParamSpec("P")


class CallMaybeDB(Protocol[P, T]):
    @overload
    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> T: ...

    @overload
    def __call__(self, db_connection: DBConnection, *args: P.args, **kwargs: P.kwargs) -> T: ...

    # Unsure whether this is required: my IDE complains if it's missing, but mypy doesn't.
    def __call__(self, *args, **kwargs) -> T: ...


def inject_db_connection(
    f: Callable[Concatenate[DBConnection, P], T]
) -> CallMaybeDB[P, T]:
    signature = inspect.signature(f)
    if "db_connection" not in signature.parameters:
        raise TypeError("Function should expect db_connection parameter")

    @wraps(f)
    def inner(*args, **kwargs) -> T:
        # Bind the passed arguments to parameter names, so an explicit
        # db_connection is detected whether it came positionally or by keyword.
        bound = signature.bind_partial(*args, **kwargs)
        if "db_connection" not in bound.arguments:
            bound.arguments["db_connection"] = DBConnection()
        return f(*bound.args, **bound.kwargs)

    return cast(CallMaybeDB[P, T], inner)
```
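The runtime core of this version is `inspect.Signature.bind_partial`, which maps positional and keyword arguments onto parameter names. The sketch below inlines that logic for a plain synchronous `get_user`, without the decorator or the typing machinery; `call_with_injection` is a hypothetical name used only for this illustration.

```python
import inspect


class DBConnection:
    ...


def get_user(db_connection: DBConnection, user_id: int) -> dict:
    return {"db_connection": db_connection, "user_id": user_id}


signature = inspect.signature(get_user)


def call_with_injection(*args, **kwargs) -> dict:
    # Hypothetical helper: the same steps as `inner` above, applied to get_user directly.
    bound = signature.bind_partial(*args, **kwargs)
    if "db_connection" not in bound.arguments:
        # Nothing was bound to db_connection, so supply a fresh connection.
        bound.arguments["db_connection"] = DBConnection()
    # bound.args / bound.kwargs follow the signature's parameter order, so the
    # injected connection is passed first even though it was added last.
    return get_user(*bound.args, **bound.kwargs)


explicit = DBConnection()
injected = call_with_injection(user_id=1)  # connection injected
by_keyword = call_with_injection(db_connection=explicit, user_id=2)  # explicit connection kept
by_position = call_with_injection(explicit, 3)  # first positional argument binds to db_connection

# Caveat: a lone positional argument also binds to db_connection, not user_id.
ambiguous = signature.bind_partial(1)
```

One consequence worth keeping in mind: a purely positional call such as `get_user(1)` binds `1` to `db_connection`, so nothing is injected and the call fails with a `TypeError` for the missing `user_id`. The question's `zip`-based check behaves the same way, so callers relying on injection presumably need to pass the remaining arguments by keyword.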