python python-typing

Typing a callable with ParamSpec or no args (`Callable[P, Any] | Callable[[], Any]`)


Here's a decorator that accepts a callable (fn: Callable[P, Any]) with the same signature as the function getting wrapped. It works and type checks.

import functools
import inspect
from typing import Any, Callable, ParamSpec, TypeVar, Union

# Parameter specification: stands for a callable's whole parameter list so the
# decorators below can forward it unchanged (via P.args / P.kwargs).
P = ParamSpec("P")
# Return type of the function being wrapped.
T = TypeVar("T")


def deco1(fn: Callable[P, Any]) -> Callable[[Callable[P, T]], Callable[P, T]]:
    def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            fn(*args, **kwargs)
            return wrapped_fn(*args, **kwargs)

        return inner

    return decorator


def logger1(x: int, y: str) -> None:
    """Print both arguments to stdout, prefixed with this logger's name."""
    message = f"logger1: x={x}, y={y}"
    print(message)


@deco1(logger1)
def wrapped1(x: int, y: str) -> None:
    """Do nothing; exists only to show `logger1` running first via `deco1`."""


# Expected output: "logger1: x=1, y=test"
wrapped1(1, "test")

Here's a decorator that accepts a callable (fn: Callable[[], Any]) with no arguments. It also works and type checks.

def deco2(fn: Callable[[], Any]) -> Callable[[Callable[P, T]], Callable[P, T]]:
    def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            fn()
            return wrapped_fn(*args, **kwargs)

        return inner

    return decorator


def logger2() -> None:
    """Print this logger's name to stdout; takes no arguments."""
    print("logger2")


@deco2(logger2)
def wrapped2(x: int, y: str) -> None:
    """Do nothing; exists only to show `logger2` running first via `deco2`."""


# Expected output: "logger2"
wrapped2(1, "test")

I'd like to combine deco1 and deco2 into a single function deco3. The following works at run-time, but I can't get the type checker to pass. Is this possible?

def deco3(
    fn: Callable[[], Any] | Callable[P, Any],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            if len(inspect.signature(fn).parameters):
                fn(*args, **kwargs)
            else:
                fn()  # pyright error: Arguments for ParamSpec "P@deco3" are missing (reportCallIssue)
            return wrapped_fn(*args, **kwargs)

        return inner

    return decorator


# Case 1: `logger1` takes the same parameters as the function it decorates.
@deco3(logger1)
def wrapped_with_logger1(x: int, y: str) -> None: ...


# Case 2: `logger2` takes no parameters while the decorated function takes two.
@deco3(logger2)  # pyright error: Argument of type "(x: int, y: str) -> None" cannot be assigned to parameter of type "() -> T@deco3"
def wrapped_with_logger2(x: int, y: str) -> None: ...


# Both calls produce the expected output at run time; the trailing comments
# record what is printed and what pyright reports.
wrapped_with_logger1(1, "test")  # prints "logger1: x=1, y=test"
wrapped_with_logger2(1, "test")  # prints "logger2" but pyright error: Expected 0 positional arguments (reportCallIssue)

Solution

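  • deco1 and deco2 are both decorator factories; they aren't decorators themselves, but rather they produce a decorator. The produced decorator's type is given by the return value of the factory (in both deco1 and deco2: Callable[[Callable[P, T]], Callable[P, T]]). However, note carefully that P doesn't play the same role in deco1 as it does in deco2:

      ◦ In deco1, P also appears in the parameter fn: Callable[P, Any], so P is fixed by fn when the factory is called; the decorated function must then match fn's signature.
      ◦ In deco2, P appears only in the return type, so P is determined later by whatever function the produced decorator is applied to; fn places no constraint on it.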

    To emphasise, the Ps are in totally different contexts in these two cases. When type variables (including ParamSpecs) are used to mean different things in different contexts, you cannot combine them by cramming them in the same signature with a union.

    The way to handle this is to use @typing.overload, which will correctly separate the 2 contexts:

    import collections.abc as cx
    import inspect
    import typing as t
    
    # Parameter specification used by the overloads and by the wrapper's
    # *args / **kwargs annotations.
    P = t.ParamSpec("P")
    # Return type of the decorated function.
    T = t.TypeVar("T")
    
    
    @t.overload
    def deco3(
        fn: cx.Callable[[], t.Any], /
    ) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
        """
        Same signature as `deco2`. Deals with the empty-parameters case.

        Produces a decorator which preserves the decorated function's signature.
        """

    @t.overload
    def deco3(  # pyright: ignore[reportOverlappingOverload]
        fn: cx.Callable[P, t.Any], /
    ) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
        """
        Same signature as `deco1`. Deals with the non-empty-parameters case.

        Produces a decorator which checks and matches the decorated function's
        signature against `fn`.

        The `# pyright: ignore` is needed because of a pyright bug (the same bug
        appears in mypy): the signatures of overloads 1 and 2 do not actually
        overlap in any practical manner.
        """

    def deco3(  # pyright: ignore[reportInconsistentOverload]
        fn: cx.Callable[..., t.Any],
    ) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
        """
        For ergonomics, loosely type the implementation's `fn` type.

        At run time `fn` is called before the decorated function: with the
        decorated function's arguments if `fn` declares any parameters,
        otherwise with no arguments. The return value of `fn` is discarded.

        Raises `TypeError` / `ValueError` (propagated from `inspect.signature`)
        when the signature of `fn` cannot be determined.
        """
        # Local import keeps this snippet's import block untouched.
        import functools

        # The signature of `fn` is inspected once here, not on every wrapped call.
        fn_takes_params = bool(inspect.signature(fn).parameters)

        def decorator(wrapped_fn: cx.Callable[P, T]) -> cx.Callable[P, T]:
            # functools.wraps copies __name__, __doc__, etc. onto the wrapper
            # and sets __wrapped__, so the decorated function stays
            # introspectable.
            @functools.wraps(wrapped_fn)
            def inner(*args: P.args, **kwargs: P.kwargs) -> T:
                if fn_takes_params:
                    fn(*args, **kwargs)
                else:
                    fn()
                return wrapped_fn(*args, **kwargs)

            return inner

        return decorator
    
    # All passing
    
    def logger1(x: int, y: str) -> None:
        """No-op stand-in for a hook that takes two parameters."""
    def logger2() -> None:
        """No-op stand-in for a hook that takes no parameters."""
    
    @deco3(logger1)
    def wrapped_with_logger1(x: int, y: str) -> None:
        """Empty demo function whose parameters mirror `logger1`'s."""

    @deco3(logger2)
    def wrapped_with_logger2(x: int, y: str) -> None:
        """Empty demo function decorated with the zero-parameter `logger2`."""

    # Both decorated functions are called with (int, str).
    wrapped_with_logger1(1, "test")
    wrapped_with_logger2(1, "test")
    
    # Type-checking in action
    
    # With `logger1`, P is taken from `logger1`'s (x: int, y: str) parameters,
    # so a decorated function that adds `z` no longer matches.
    @deco3(logger1)  # Fail: extra parameter, doesn't match `logger1`'s signature
    def wrapped_with_logger1_fail(x: int, y: str, z: bytes) -> None: ...
    
    # With the zero-parameter `logger2`, P comes from the decorated function alone.
    @deco3(logger2)  # Pass: Anything goes - the signature is preserved
    def wrapped_with_logger2_anything_goes(x: int, y: str, z: bytes) -> None: ...
    # The preserved signature still requires `z`, so leaving it out is reported.
    wrapped_with_logger2_anything_goes(1, "test")  # Fail: missing argument `z: bytes`