
Properly typing a Python decorator with ParamSpec and Concatenate that allows for arbitrary argument positioning?


I have an existing Python decorator that ensures a method receives a psycopg AsyncConnection instance. I'm trying to update its typing to use ParamSpec and Concatenate, because the current implementation isn't type-safe, but I'm stuck.

Here's the current implementation:

from typing import Any, Callable, Coroutine, TypeVar

from psycopg import AsyncConnection

R = TypeVar("R")

def ensure_conn(func: Callable[..., Coroutine[Any, Any, R]]) -> Callable[..., Coroutine[Any, Any, R]]:
    """Ensure the function has a conn argument. If conn is not provided, generate a new connection and pass it to the function."""

    async def wrapper(*args: Any, **kwargs: Any) -> R:
        # Get named keyword argument conn, or find an AsyncConnection in the args
        kwargs_conn = kwargs.get("conn")
        conn_arg: AsyncConnection[Any] | None = None
        if isinstance(kwargs_conn, AsyncConnection):
            conn_arg = kwargs_conn
        else:
            for arg in args:
                if isinstance(arg, AsyncConnection):
                    conn_arg = arg
                    break
        if conn_arg is not None:
            # If conn is provided, call the method as is
            return await func(*args, **kwargs)
        else:
            # If conn is not provided, generate a new connection and pass it to the method
            # (DbDriver is the project's own connection factory, not shown here)
            db_driver = DbDriver()
            async with db_driver.connection() as conn:
                return await func(*args, **kwargs, conn=conn)

    return wrapper

Current usage:

@ensure_conn
async def get_user(user_id: UUID, conn: AsyncConnection):
    async with conn.cursor() as cursor:
        ...  # do stuff

...but this call passes type checking, even though the extra third argument should be rejected:

get_user('519766c5-af86-47ea-9fa9-cee0c0de66b1', conn, arg_that_should_fail_typing)
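The reason is the Callable[..., ...] annotation: the ellipsis tells the type checker that the decorated function accepts any arguments, so the mistake only surfaces when the call actually runs. A stripped-down illustration, using a hypothetical passthrough decorator with the same annotations and a simplified get_user stand-in:

```python
import asyncio
from typing import Any, Callable, Coroutine, TypeVar

R = TypeVar("R")


def passthrough(
    func: Callable[..., Coroutine[Any, Any, R]],
) -> Callable[..., Coroutine[Any, Any, R]]:
    # `...` means "accepts any arguments", so calls through the returned
    # callable are never checked against func's real parameters.
    async def wrapper(*args: Any, **kwargs: Any) -> R:
        return await func(*args, **kwargs)

    return wrapper


@passthrough
async def get_user(user_id: str, conn: object) -> str:
    return user_id


def call_with_extra_argument() -> str:
    try:
        # A type checker accepts this call; Python rejects it when it runs.
        asyncio.run(get_user("519766c5", object(), "extra"))
    except TypeError as exc:
        return str(exc)
    return "no error"


print(call_with_extra_argument())
```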

Here's the closest I've gotten with ParamSpec and Concatenate:

def ensure_conn_decorator[**P, R](func: Callable[Concatenate[AsyncConnection[Any], P], R]) -> Coroutine[Any, Any, R]:
    """Ensure the function has a conn argument. If conn is not provided, generate a new connection and pass it to the function."""
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Get named keyword argument conn, or find an AsyncConnection in the args
        kwargs_conn = kwargs.get("conn")
        conn_arg: AsyncConnection[Any] | None = None
        if isinstance(kwargs_conn, AsyncConnection):
            conn_arg = kwargs_conn
        elif not conn_arg:
            for arg in args:
                if isinstance(arg, AsyncConnection):
                    conn_arg = arg
                    break
        if conn_arg:
            # If conn is provided, call the method as is
            return await func(*args, **kwargs)
        else:
            # If conn is not provided, generate a new connection and pass it to the method
            db_driver = DbDriver()
            async with db_driver.connection() as conn:
                return await func(*args, **kwargs, conn=conn)

    return wrapper

Pyright (Pylance) reports:

Expression of type "(**P@ensure_conn_decorator) -> Coroutine[Any, Any, R@ensure_conn_decorator]" is incompatible with return type "Coroutine[Any, Any, R@ensure_conn_decorator]"
  "function" is incompatible with "Coroutine[Any, Any, R@ensure_conn_decorator]"

Solution

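  • First, a note on a limitation of ParamSpec and Concatenate: Concatenate can only prepend positional parameters to a ParamSpec. It can't insert a parameter at an arbitrary position or add a keyword parameter, so conn has to move to the front of the decorated function's parameter list, and the wrapped function's public signature (P) no longer includes conn at all.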

    The example below type checks for me with pyright/Pylance and mypy. I've added minimal stand-ins for AsyncConnection and DbDriver so everything can be checked on its own, and the inferred type of out is User, as expected:

    import asyncio
    import contextlib
    from typing import Any, AsyncIterator, Callable, Concatenate, Coroutine
    from uuid import UUID

    # Minimal stand-ins for psycopg's AsyncConnection and the project's DbDriver
    @contextlib.asynccontextmanager
    async def aclosing() -> AsyncIterator["AsyncConnection[Any]"]:
        yield AsyncConnection()

    class AsyncConnection[T]:
        def cursor(self):
            return aclosing()

    class DbDriver:
        def connection(self):
            return aclosing()

    def ensure_conn[R, **P](
        func: Callable[Concatenate[AsyncConnection[Any], P], Coroutine[Any, Any, R]],
    ) -> Callable[P, Coroutine[Any, Any, R]]:
        """Pass an AsyncConnection as the first argument, opening a new one if none was provided."""

        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
            # P excludes conn, so type-checked callers cannot pass one; an untyped
            # caller might still pass conn=..., so reuse it and remove it from
            # kwargs rather than handing func the same argument twice.
            kwargs_conn = kwargs.get("conn")
            if isinstance(kwargs_conn, AsyncConnection):
                del kwargs["conn"]
                return await func(kwargs_conn, *args, **kwargs)
            # Otherwise open a new connection and pass it as the first argument
            db_driver = DbDriver()
            async with db_driver.connection() as conn:
                return await func(conn, *args, **kwargs)

        return wrapper

    class User:
        def __init__(self, uuid: UUID) -> None:
            self.uuid = uuid

    @ensure_conn
    async def get_user(conn: AsyncConnection[Any], user_id: UUID) -> User:
        async with conn.cursor() as cursor:
            return User(user_id)

    async def main() -> None:
        uuid = UUID('519766c5-af86-47ea-9fa9-cee0c0de66b1')
        out = await get_user(uuid)  # inferred type of out: User
        print(out.uuid)

    asyncio.run(main())
    

    Hope this helps!