
Type annotate decorator that changes decorated function arguments


I want to design a decorator that allows the wrapped function to take either a float or a NumPy array of floats. If the argument is a float, a float should be returned; if it is a NumPy array, a NumPy array should be returned.

Below is my MWE and latest attempt. I am using VS Code with Pylance v2024.3.2 and Python 3.12.3. I have a large number of functions I'd like to apply this decorator to; if I were only dealing with one or two of them, I could drop the decorator approach entirely and use @overload.

The type error I get is the following:

    Argument of type "float" cannot be assigned to parameter "a" of type "NDArray[float64]" in function "func"
      "float" is incompatible with "NDArray[float64]"

    import numpy as np
    from numpy import float64
    from numpy.typing import NDArray
    from collections.abc import Callable
    from typing import TypeAlias, TypeVar

    T = TypeVar('T', float, NDArray[float64])
    PreWrapFunc: TypeAlias = Callable[[NDArray[float64]], NDArray[float64]]
    PostWrapFunc: TypeAlias = Callable[[T], T]

    def my_decorator(method: PreWrapFunc) -> PostWrapFunc:

        def wrapper(arg: T) -> T:
            if isinstance(arg, float):
                result = method(np.array([arg,]))
                return result[0]
            else:
                return method(arg)
        return wrapper

    @my_decorator
    def func(a: NDArray[float64]) -> NDArray[float64]:
        return a * 2

    func(1.0) # 2.0, type error is happening here!
    func(np.array([1.0,])) # array([2.])

Solution

  • Define a Protocol whose __call__ has overloaded signatures, give the inner wrapper the same overloads, and use the protocol as the return type of your decorator:

    import numpy as np
    from numpy import float64
    from numpy.typing import NDArray
    from collections.abc import Callable
    from typing import Protocol, overload
    
    
    class AsFloatOrArray(Protocol):
        @overload
        def __call__(self, arg: float) -> float: ...
        
        @overload
        def __call__(self, arg: NDArray[float64]) -> NDArray[float64]: ...
    
        def __call__(self, arg): ...
    
    
    def my_decorator(
        method: Callable[[NDArray[float64]], NDArray[float64]],
    ) -> AsFloatOrArray:
        @overload
        def wrapper(arg: float) -> float: ...
    
        @overload
        def wrapper(arg: NDArray[float64]) -> NDArray[float64]: ...
        
        def wrapper(arg):
            if isinstance(arg, float):
                return method(np.array([arg]))[0]
            else:
                return method(arg)
        return wrapper
    
    
    @my_decorator
    def func(a: NDArray[float64]) -> NDArray[float64]:
        return a * 2
    
    
    print(func(1.0))
    print(func(np.array([1.0])))
    

    Or, more concisely, write the decorator as a class and provide overloaded signatures for the __call__ method:

    import numpy as np
    from numpy import float64
    from numpy.typing import NDArray
    from collections.abc import Callable
    from typing import overload
    
    
    class MyDecorator:
        def __init__(
            self,
            method: Callable[[NDArray[float64]], NDArray[float64]],
        ) -> None:
            self.method = method
    
        @overload
        def __call__(self, arg: float) -> float: ...
    
        @overload
        def __call__(self, arg: NDArray[float64]) -> NDArray[float64]: ...
        
        def __call__(self, arg):
            if isinstance(arg, float):
                return self.method(np.array([arg]))[0]
            else:
                return self.method(arg)
    
    
    @MyDecorator
    def func(a: NDArray[float64]) -> NDArray[float64]:
        return a * 2
    
    
    print(func(1.0))
    print(func(np.array([1.0])))
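Both versions route on isinstance(arg, float). Type checkers accept an int wherever float is annotated, so func(3) type-checks against the float overload, but at runtime it skips the scalar branch and passes the bare int straight to method, which expects an array. The sketch below is one possible adjustment of the class-based version, not part of the original answer: it routes every 0-dimensional input (int, float, np.float64, 0-d arrays) through the scalar path using np.ndim, and calls functools.update_wrapper so the decorated object keeps the original function's __name__ and __doc__. The docstring on func is added purely for illustration.

```python
import functools
from collections.abc import Callable
from typing import overload

import numpy as np
from numpy import float64
from numpy.typing import NDArray


class MyDecorator:
    def __init__(
        self,
        method: Callable[[NDArray[float64]], NDArray[float64]],
    ) -> None:
        self.method = method
        # Copy __name__, __doc__, __module__, etc. from the wrapped function
        # onto the decorator instance, so func.__name__ is still "func".
        functools.update_wrapper(self, method)

    @overload
    def __call__(self, arg: float) -> float: ...

    @overload
    def __call__(self, arg: NDArray[float64]) -> NDArray[float64]: ...

    def __call__(self, arg):
        # np.ndim() is 0 for every scalar (int, float, np.float64), so an int,
        # which type checkers accept for the `float` overload, is also wrapped
        # in a 1-element float64 array instead of reaching `method` as-is.
        if np.ndim(arg) == 0:
            return float(self.method(np.array([arg], dtype=float64))[0])
        return self.method(arg)


@MyDecorator
def func(a: NDArray[float64]) -> NDArray[float64]:
    """Double every element."""
    return a * 2


print(func(1.0))              # 2.0
print(func(3))                # 6.0
print(func(np.array([1.0])))  # [2.]
print(func.__name__)          # func
```

The float(...) call converts the np.float64 element into a plain Python float; since np.float64 already subclasses float, that conversion is optional and mainly keeps the runtime type identical to what the float overload promises.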