Here's a decorator factory that accepts a callable (`fn: Callable[P, Any]`) with the same signature as the function being wrapped. It works and type-checks.
import inspect
from typing import Any, Callable, ParamSpec, TypeVar, Union
P = ParamSpec("P")
T = TypeVar("T")
def deco1(fn: Callable[P, Any]) -> Callable[[Callable[P, T]], Callable[P, T]]:
def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
fn(*args, **kwargs)
return wrapped_fn(*args, **kwargs)
return inner
return decorator
def logger1(x: int, y: str) -> None:
    """Print one line showing the values of ``x`` and ``y``."""
    message = f"logger1: x={x}, y={y}"
    print(message)
# deco1 binds P from logger1, so wrapped1 must have logger1's signature (x: int, y: str).
# Calling wrapped1 first calls logger1 with the same arguments, then wrapped1's body.
@deco1(logger1)
def wrapped1(x: int, y: str) -> None: ...
wrapped1(1, "test") # prints "logger1: x=1, y=test"
Here's a decorator factory that accepts a callable (`fn: Callable[[], Any]`) that takes no arguments. It also works and type-checks.
def deco2(fn: Callable[[], Any]) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Return a decorator that calls the zero-argument ``fn`` before the wrapped call.

    Here ``P`` only carries the decorated function's own signature through to the
    result; ``fn``'s (empty) signature is unrelated to it. ``fn``'s return value is
    ignored; the wrapped function's result is returned.
    """

    def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            # The wrapped function's arguments are deliberately not forwarded to fn.
            fn()
            outcome = wrapped_fn(*args, **kwargs)
            return outcome

        return inner

    return decorator
def logger2() -> None:
    """Print the fixed marker line ``logger2``; takes no arguments."""
    # Plain literal: the previous f-string had no placeholders (ruff F541).
    print("logger2")
# logger2 takes no arguments; deco2's P comes only from wrapped2, so wrapped2 keeps
# its own signature (x: int, y: str). Calling it runs logger2() first, then the body.
@deco2(logger2)
def wrapped2(x: int, y: str) -> None: ...
wrapped2(1, "test") # prints "logger2"
I'd like to combine `deco1` and `deco2` into a single function, `deco3`. The following works at run time, but I can't get the type checker to pass. Is this possible?
def deco3(
    fn: Callable[[], Any] | Callable[P, Any],
) -> Callable[[Callable[P, T]], Callable[P, T]]:
    """Attempt to merge ``deco1`` and ``deco2`` via a union-typed ``fn``.

    At run time, ``fn`` receives the wrapped function's arguments if its signature
    declares any parameters, and is called with no arguments otherwise. This works,
    but the union annotation does not type-check (see the pyright comments).
    """
    # Inspect fn once, when the factory is called, rather than on every call of
    # the wrapped function. A callable that inspect.signature cannot introspect
    # therefore fails here, at decoration time.
    fn_takes_args = bool(inspect.signature(fn).parameters)

    def decorator(wrapped_fn: Callable[P, T]) -> Callable[P, T]:
        def inner(*args: P.args, **kwargs: P.kwargs) -> T:
            if fn_takes_args:
                fn(*args, **kwargs)
            else:
                fn()  # pyright error: Arguments for ParamSpec "P@deco3" are missing (reportCallIssue)
            return wrapped_fn(*args, **kwargs)

        return inner

    return decorator
# Using the union-typed deco3 above. Both decorated functions run correctly, but
# pyright binds P from fn: with logger2 that is an empty parameter list, so it
# expects the decorated function to take no parameters (hence the errors below).
@deco3(logger1)
def wrapped_with_logger1(x: int, y: str) -> None: ...
@deco3(logger2) # pyright error: Argument of type "(x: int, y: str) -> None" cannot be assigned to parameter of type "() -> T@deco3"
def wrapped_with_logger2(x: int, y: str) -> None: ...
wrapped_with_logger1(1, "test") # prints "logger1: x=1, y=test"
wrapped_with_logger2(1, "test") # prints "logger2" but pyright error: Expected 0 positional arguments (reportCallIssue)
`deco1` and `deco2` are both decorator factories: they aren't decorators themselves; rather, they produce a decorator. The produced decorator's type is given by the factory's return type (in both `deco1` and `deco2`: `Callable[[Callable[P, T]], Callable[P, T]]`). However, note carefully that `P` doesn't play the same role in `deco1` as it does in `deco2`:
- In `deco1`, `P` is used to statically check that the signature of the function being decorated (`wrapped1`) matches the signature of the argument to `deco1` (which is a callable; `fn: Callable[P, Any]`).
- In `deco2`, `P` is simply used to preserve the signature of the function being decorated (`wrapped2`).

To emphasise, the `P`s are in totally different contexts in these two cases. When type variables (including `ParamSpec`s) are used to mean different things in different contexts, you cannot combine them by cramming them into the same signature with a union.
The way to handle this is with `@typing.overload`, which correctly separates the two contexts:
import collections.abc as cx
import inspect
import typing as t
P = t.ParamSpec("P")
T = t.TypeVar("T")
@t.overload
def deco3(
fn: cx.Callable[[], t.Any], /
) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
"""
Same signature as `deco2`. Deals with the empty-parameters case.
Produces a decorator which preserves the decorated function's signature.
"""
@t.overload
def deco3( # pyright: ignore[reportOverlappingOverload]
fn: cx.Callable[P, t.Any], /
) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
"""
Same signature as `deco1`. Deals with the non-empty-parameters case.
Produces a decorator which checks and matches the decorated function's signature
against `fn`.
The addition of `# pyright: ignore` is a pyright bug (the same bug appears in mypy).
The signatures of overloads 1 and 2 do not actually overlap in any practical manner.
"""
def deco3( # pyright: ignore[reportInconsistentOverload]
fn: cx.Callable[..., t.Any],
) -> cx.Callable[[cx.Callable[P, T]], cx.Callable[P, T]]:
"""
For ergonomics, loosely type the implementation's `fn` type.
"""
def decorator(wrapped_fn: cx.Callable[P, T]) -> cx.Callable[P, T]:
def inner(*args: P.args, **kwargs: P.kwargs) -> T:
if len(inspect.signature(fn).parameters):
fn(*args, **kwargs)
else:
fn()
return wrapped_fn(*args, **kwargs)
return inner
return decorator
# All passing
# Stub loggers: only their signatures matter to the type checker; their bodies
# are `...`, so they print nothing at run time.
def logger1(x: int, y: str) -> None: ...
def logger2() -> None: ...
# logger1 takes parameters -> second overload: the decorated function must match it.
@deco3(logger1)
def wrapped_with_logger1(x: int, y: str) -> None: ...
# logger2 takes none -> first overload: the decorated function keeps its own signature.
@deco3(logger2)
def wrapped_with_logger2(x: int, y: str) -> None: ...
wrapped_with_logger1(1, "test")
wrapped_with_logger2(1, "test")
# Type-checking in action
# "Fail"/"Pass" below are the static type checker's verdicts.
@deco3(logger1) # Fail: extra parameter, doesn't match `logger1`'s signature
def wrapped_with_logger1_fail(x: int, y: str, z: bytes) -> None: ...
@deco3(logger2) # Pass: Anything goes - the signature is preserved
def wrapped_with_logger2_anything_goes(x: int, y: str, z: bytes) -> None: ...
# NOTE: this call also raises TypeError at run time: the wrapper forwards only the
# two arguments given, and the decorated function requires `z`.
wrapped_with_logger2_anything_goes(1, "test") # Fail: missing argument `z: bytes`