I have an existing Python decorator that ensures a coroutine function is given a psycopg `AsyncConnection` instance. I'm trying to update its typing to use `ParamSpec` and `Concatenate`, because the current implementation isn't type-safe, but I'm getting stuck.
Here's the current implementation:
def ensure_conn(
    func: Callable[..., Coroutine[Any, Any, R]],
) -> Callable[..., Coroutine[Any, Any, R]]:
    """Ensure the function has a conn argument. If conn is not provided, generate a new connection and pass it to the function.

    A connection counts as provided when the ``conn`` keyword holds an
    ``AsyncConnection`` or any positional argument is one; the call is then
    forwarded to ``func`` unchanged. Otherwise a connection is opened with
    ``DbDriver().connection()`` and passed to ``func`` as the ``conn`` keyword.
    An explicit ``conn=None`` is treated the same as omitting ``conn``.

    Note: because the signature is ``Callable[..., ...]``, type checkers cannot
    validate the arguments of calls to the decorated function.
    """
    import functools

    @functools.wraps(func)
    async def wrapper(*args: Any, **kwargs: Any) -> R:
        if "conn" in kwargs and kwargs["conn"] is None:
            # Drop the placeholder so it cannot collide with the injected `conn=`.
            del kwargs["conn"]
        provided = isinstance(kwargs.get("conn"), AsyncConnection) or any(
            isinstance(arg, AsyncConnection) for arg in args
        )
        if provided:
            # The caller supplied a connection: call the function as is.
            return await func(*args, **kwargs)
        # No connection supplied: open one for the duration of the call.
        async with DbDriver().connection() as conn:
            return await func(*args, **kwargs, conn=conn)

    return wrapper
Current usage:
@ensure_conn
async def get_user(user_id: UUID, conn: AsyncConnection):
    """Example decorated coroutine from the question; the query body is elided."""
    async with conn.cursor() as cursor:
        ...  # do stuff
...but I can call it like this and it still passes type checking:
# `get_user` takes two parameters (user_id, conn) but this passes three; a
# ParamSpec-based decorator should make a type checker reject the call.
# (`get_user` is `async def`, so a real call would also need `await`.)
get_user('519766c5-af86-47ea-9fa9-cee0c0de66b1', conn, arg_that_should_fail_typing)
Here's the closest I've gotten using `ParamSpec` and `Concatenate`:
# NOTE(review): this is the question's failing attempt, kept as-is so it matches
# the type-checker errors quoted after it. Problems visible in this code:
#   * The return annotation is `Coroutine[Any, Any, R]`, but the decorator
#     returns `wrapper` (a function). That mismatch is the reported error.
#     The return type should be `Callable[P, Coroutine[Any, Any, R]]`.
#   * `func` is declared to return `R`, yet it is awaited below. Its return
#     type should be `Coroutine[Any, Any, R]`.
#   * `Concatenate[AsyncConnection[Any], P]` makes the connection `func`'s first
#     positional(-only) parameter, but neither call below passes it that way.
def ensure_conn_decorator[**P, R](func: Callable[Concatenate[AsyncConnection[Any], P], R]) -> Coroutine[Any, Any, R]:
    """Ensure the function has a conn argument. If conn is not provided, generate a new connection and pass it to the function."""
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Get named keyword argument conn, or find an AsyncConnection in the args
        kwargs_conn = kwargs.get("conn")
        conn_arg: AsyncConnection[Any] | None = None
        if isinstance(kwargs_conn, AsyncConnection):
            conn_arg = kwargs_conn
        # NOTE(review): conn_arg is still None at this point, so this is always true.
        elif not conn_arg:
            for arg in args:
                if isinstance(arg, AsyncConnection):
                    conn_arg = arg
                    break
        if conn_arg:
            # If conn is provided, call the method as is
            # NOTE(review): nothing is prepended here, so this call does not match
            # Concatenate[AsyncConnection[Any], P] from the type checker's view.
            return await func(*args, **kwargs)
        else:
            # If conn is not provided, generate a new connection and pass it to the method
            db_driver = DbDriver()
            async with db_driver.connection() as conn:
                # NOTE(review): the connection goes in as keyword `conn=`; to match the
                # Concatenate signature it must be positional: func(conn, *args, **kwargs).
                return await func(*args, **kwargs, conn=conn)
    return wrapper
The type checker reports these problems:
Expression of type "(**P@ensure_conn_decorator) -> Coroutine[Any, Any, R@ensure_conn_decorator]" is incompatible with return type "Coroutine[Any, Any, R@ensure_conn_decorator]"
"function" is incompatible with "Coroutine[Any, Any, R@ensure_conn_decorator]"
First, some notes about the limitations of `ParamSpec` and `Concatenate`:
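`Concatenate` can only prepend parameters. The parameters it adds always come first, and its final argument must be the `ParamSpec`. Those prepended parameters are also positional-only. That means the connection has to be the first parameter of the decorated function, which is why `conn` moves to the front of `get_user` in the example below.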
A good way to think about it is that the decorator "consumes" the first parameter. It accepts any function whose first parameter is an `AsyncConnection`, and returns a function with that parameter removed.
The example below type-checks cleanly for me in pyright/Pylance/mypy. I've added minimal stand-ins for the missing classes so that everything can be checked. The inferred type of `out` is `User`, so it seems to work:
import asyncio
from typing import Any, Callable, Concatenate, Coroutine, cast
import contextlib
from uuid import UUID
@contextlib.asynccontextmanager
async def aclosing():
try:
yield AsyncConnection()
finally:
pass
class AsyncConnection[T]:
def cursor(self):
return aclosing()
class DbDriver:
def connection(self):
return aclosing()
def ensure_conn[R, **P](func: Callable[Concatenate[AsyncConnection[Any], P], Coroutine[Any, Any, R]]) -> Callable[P, Coroutine[Any, Any, R]]:
"""Ensure the function has a conn argument. If conn is not provided, generate a new connection and pass it to the function."""
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
# Get named keyword argument conn, or find an AsyncConnection in the args
kwargs_conn = kwargs.get("conn")
conn_arg: AsyncConnection[Any] | None = None
if isinstance(kwargs_conn, AsyncConnection):
conn_arg = kwargs_conn
elif not conn_arg:
for arg in args:
if isinstance(arg, AsyncConnection):
conn_arg = arg
break
if conn_arg:
# If conn is provided, call the method as is
return await func(conn_arg, *args, **kwargs)
else:
# If conn is not provided, generate a new connection and pass it to the method
db_driver = DbDriver()
async with db_driver.connection() as conn:
return await func(conn, *args, **kwargs)
return wrapper
class User:
    """Minimal user model returned by the example ``get_user``."""

    def __init__(self, uuid: UUID) -> None:
        # Stored exactly as given; no validation or conversion happens here.
        self.uuid = uuid
@ensure_conn
async def get_user(conn: AsyncConnection, user_id: UUID):
    """Build a ``User`` for *user_id* while a cursor on *conn* is open.

    ``conn`` is the leading parameter that ``ensure_conn`` fills in, so callers
    invoke this as ``await get_user(user_id)``.
    """
    async with conn.cursor():
        user = User(user_id)
    return user
async def main():
    """Run the example: fetch a user without passing a connection explicitly."""
    # `cast` would only relabel the str for the type checker (it stays a str at
    # runtime); construct a real UUID instead.
    user_id = UUID('519766c5-af86-47ea-9fa9-cee0c0de66b1')
    out = await get_user(user_id)  # type checkers infer `User` for `out`
# Only run the example when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
Hope this helps!