
Wrapper stripping the generic parameter of a function erases its type parameter


In the following Python code, I define a generic function wrapper which takes a function of type T → T and replaces it with a function without arguments returning an instance of Delay[T]. This instance simply stores the original function so that it can be called later.

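from collections.abc import Callable
from typing import reveal_type  # Python 3.11+; lets the reveal_type calls below also run outside mypy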

class Delay[T]:
    def __init__(self, wrapped: Callable[[T], T]):
        self.wrapped = wrapped

def wrapper[T](wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping

When using this wrapper with a normal function, the type checker is happy:

@wrapper
def fun1(arg: str) -> str:
    return arg

reveal_type(fun1) # mypy says: "def () -> Delay[builtins.str]"
reveal_type(fun1()) # mypy says: "Delay[builtins.str]"
reveal_type(fun1().wrapped) # mypy says: "def (builtins.str) -> builtins.str"
reveal_type(fun1().wrapped("test")) # mypy says: "builtins.str"

However, when the wrapped function is generic, the type argument somehow gets erased:

@wrapper
def fun2[T](arg: T) -> T:
    return arg

reveal_type(fun2) # mypy says: "def () -> Delay[Never]"
reveal_type(fun2()) # mypy says: "Delay[Never]"
reveal_type(fun2().wrapped) # mypy says: "def (Never) -> Never"
reveal_type(fun2().wrapped("test")) # mypy says: "Never"
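The erasure happens only in the type checker; at runtime both decorated functions behave the same. The sketch below restates the Delay and wrapper definitions from above with an explicit TypeVar instead of the 3.12 [T] syntax (an assumption made only so it also runs on older interpreters) and checks what the stored function actually returns.

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class Delay(Generic[T]):
    """Stores a T -> T function so that it can be called later."""

    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped


def wrapper(wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
    """Replace a T -> T function with a zero-argument factory of Delay[T]."""

    def wrapping() -> Delay[T]:
        return Delay(wrapped)

    return wrapping


@wrapper
def fun1(arg: str) -> str:
    return arg


@wrapper
def fun2(arg: T) -> T:  # generic: per the question, mypy reports Delay[Never] for fun2()
    return arg


# At runtime nothing is erased: the stored function is called as usual.
print(fun1().wrapped("test"))  # test
print(fun2().wrapped("test"))  # test
print(fun2().wrapped(1))       # 1
```

So whatever the checker infers, the problem is purely one of static typing.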

I would have expected the type checker to infer the type of fun2 as def [T] () -> Delay[T], the type of fun2().wrapped as def [T] (T) -> T, and the type of the last line as str.

Note that pyright seems to behave similarly to mypy here.

Is there something invalid with the type annotations in my code? Is this a known limitation of the Python type system, or a bug in mypy and pyright?


Solution

  • Based on what I think you're trying to do (mypy Playground with a hacky solution), I would say your annotations are invalid: you're trying to use the same symbol T to refer to type variables bound in different scopes.


    You already know that fun1: "def () -> Delay[builtins.str]" here ...

    @wrapper
    def fun1(arg: str) -> str:
        return arg
    

    ... but you cannot have fun2: "def () -> Delay[T]" here.

    @wrapper
    def fun2[T](arg: T) -> T:
        return arg
    

    This is because fun2 is a variable at module scope, and module-scoped variables can't have types with a free type variable, because modules don't bind type variables (only generic classes and generic functions can bind type variables in their bodies). Something with type Delay[T] at module scope can't ever be fulfilled; you can't create an instance of T at this scope.
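To make the scoping point concrete: a free type variable only means something inside the generic function or class that binds it, so at module scope you have to commit to a concrete type argument yourself. A minimal sketch, where identity, delay_for, str_delay and int_delay are hypothetical names and Delay is the question's class spelled with an explicit TypeVar:

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")


class Delay(Generic[T]):
    def __init__(self, wrapped: Callable[[T], T]) -> None:
        self.wrapped = wrapped


def identity(arg: T) -> T:
    return arg


def delay_for(kind: type[T], fn: Callable[[T], T]) -> Delay[T]:
    # Inside this generic function T is bound: each call site picks it
    # (here via the kind argument), so Delay[T] is a meaningful type.
    return Delay(fn)


# At module scope nothing binds T, so a concrete type is chosen explicitly;
# the generic identity is then checked against str -> str and int -> int.
str_delay: Delay[str] = Delay(identity)
int_delay = delay_for(int, identity)  # a type checker should infer Delay[int]
```

Each module-level name ends up with a concrete parameterisation, which is exactly what fun2 cannot get from wrapper on its own.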

    What you're trying to do might be this:

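    Delay[Never] indicates something that can't be parameterised by a concrete type at module scope. Hence, a workaround is to declare wrapped through a descriptor type, visible only to the type checker, whose __get__ overloads return a generic callable def [R] (R) -> R for Delay[Never] and a callable def (T) -> T for a concretely parameterised Delay[T].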

    if TYPE_CHECKING:
    
        class Wrapped:
            @overload  # type: ignore[no-overload-impl]
            def __get__(self, instance: None, owner: type[object], /) -> Self: ...
            @overload
            def __get__[R](
                self, instance: Delay[Never], owner: type[Delay[Never]], /
            ) -> Callable[[R], R]: 
                """
                Can't be parameterised by a concrete type, return a callable which
                just returns the same type as it receives
                """
            @overload
            def __get__[T](
                self, instance: Delay[T], owner: type[Delay[T]], /
            ) -> Callable[[T], T]: 
                """
                Can be parameterised by a concrete type, return a callable which
                receives and returns this concrete type
                """
            def __set__[T](
                self, instance: Delay[Any], value: Callable[[T], T], /
            ) -> None: ...
    
    @wrapper
    def fun1(arg: str) -> str:
        return arg
    
    # `Delay[str]` (parameterised by concrete type `str`)
    reveal_type(fun1().wrapped)  # "def (builtins.str) -> builtins.str"
    
    @wrapper
    def fun2[T](arg: T) -> T:
        return arg
    
    # `Delay[Never]` (can't fulfil parameterisation)
    reveal_type(fun2().wrapped)  # def [R](R) -> R
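The overloads on Wrapped.__get__ rely on the standard descriptor protocol: accessing the attribute on the class passes instance=None (the overload returning Self), while accessing it on an instance passes that instance, whose static type then selects one of the other overloads. The Wrapped class above exists only for the type checker; the runtime descriptor below (StoredCallable and Holder are hypothetical names) is a sketch of what the same protocol does when it actually executes.

```python
from collections.abc import Callable
from typing import Any, Optional


class StoredCallable:
    """Data descriptor that stores a callable on each instance under a private name."""

    def __set_name__(self, owner: type, name: str) -> None:
        self.private_name = "_" + name

    def __get__(self, instance: Optional[object], owner: type) -> Any:
        if instance is None:
            # Class access (Holder.wrapped): mirrors the "instance: None -> Self" overload.
            return self
        # Instance access (holder.wrapped): return the stored callable.
        return getattr(instance, self.private_name)

    def __set__(self, instance: object, value: Callable[[Any], Any]) -> None:
        setattr(instance, self.private_name, value)


class Holder:
    wrapped = StoredCallable()

    def __init__(self, wrapped: Callable[[Any], Any]) -> None:
        self.wrapped = wrapped  # routed through StoredCallable.__set__


holder = Holder(lambda x: x)
print(Holder.wrapped)          # the StoredCallable descriptor itself
print(holder.wrapped("test"))  # test
```

Because it defines __set__, StoredCallable is a data descriptor, so the assignment in __init__ goes through it instead of creating a plain instance attribute; the answer's Wrapped declares __set__ for the same reason, so that self.wrapped = wrapped still type-checks.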
    

    Full solution below:

    from collections.abc import Callable
    from typing import TYPE_CHECKING, Any, Never, Self, overload, reveal_type
    
    
    if TYPE_CHECKING:
    
        class Wrapped:
            @overload  # type: ignore[no-overload-impl]
            def __get__(self, instance: None, owner: type[object], /) -> Self: ...
            @overload
            def __get__[R](
                self, instance: Delay[Never], owner: type[Delay[Never]], /
            ) -> Callable[[R], R]: 
                """
                Can't be parameterised by a concrete type, return a callable which
                just returns the same type as it receives
                """
            @overload
            def __get__[T](
                self, instance: Delay[T], owner: type[Delay[T]], /
            ) -> Callable[[T], T]: 
                """
                Can be parameterised by a concrete type, return a callable which
                receives and returns this concrete type
                """
            def __set__[T](
                self, instance: Delay[Any], value: Callable[[T], T], /
            ) -> None: ...
    
    
    class Delay[T]:
        if TYPE_CHECKING:
            wrapped = Wrapped()
    
        def __init__(self, wrapped: Callable[[T], T]):
            self.wrapped = wrapped
    
    
    def wrapper[T](wrapped: Callable[[T], T]) -> Callable[[], Delay[T]]:
        def wrapping() -> Delay[T]:
            return Delay(wrapped)
    
        return wrapping
    
    
    @wrapper
    def fun1(arg: str) -> str:
        return arg
    
    
    reveal_type(fun1)  # mypy says: "def () -> Delay[builtins.str]"
    reveal_type(fun1())  # mypy says: "Delay[builtins.str]"
    reveal_type(fun1().wrapped)  # mypy says: "def (builtins.str) -> builtins.str"
    reveal_type(fun1().wrapped("test"))  # mypy says: "builtins.str"
    reveal_type(fun1().wrapped(1))  # Error
    
    
    @wrapper
    def fun2[T](arg: T) -> T:
        return arg
    
    
    reveal_type(fun2)  # still "def () -> Delay[Never]"
    reveal_type(fun2())  # still "Delay[Never]"
    reveal_type(fun2().wrapped)  # def [R](R) -> R
    reveal_type(fun2().wrapped("test"))  # str
    reveal_type(fun2().wrapped(1))  # int