I want to design a decorator that allows the wrapped function to take either a float or a NumPy array of floats. If the argument is a float, a float should be returned; if it is a NumPy array, a NumPy array should be returned.
Below is my MWE and latest attempt. I am using VS Code with Pylance v2024.3.2 and Python 3.12.3. I have a large number of functions I'd like to apply this decorator to; if I were only dealing with one or two, I could drop the decorator approach entirely and use `@overload`.
The type error I get is the following:
Argument of type "float" cannot be assigned to parameter "a" of type "NDArray[float64]" in function "func"
"float" is incompatible with "NDArray[float64]"
import numpy as np
from numpy import float64
from numpy.typing import NDArray
from collections.abc import Callable
from typing import TypeAlias, TypeVar
T = TypeVar('T', float, NDArray[float64])
PreWrapFunc: TypeAlias = Callable[[NDArray[float64]], NDArray[float64]]
PostWrapFunc: TypeAlias = Callable[[T], T]
def my_decorator(method: PreWrapFunc) -> PostWrapFunc:
def wrapper(arg: T) -> T:
if isinstance(arg, float):
result = method(np.array([arg,]))
return result[0]
else:
return method(arg)
return wrapper
@my_decorator
def func(a: NDArray[float64]) -> NDArray[float64]:
    """Return ``a`` multiplied by two."""
    doubled = 2 * a
    return doubled
func(1.0)  # returns 2.0 at runtime, but Pylance reports the type error on this call
func(np.array([1.0,]))  # returns array([2.])
Define a protocol with overloaded signatures for your wrapper and then use that as the return type of your decorator:
import numpy as np
from numpy import float64
from numpy.typing import NDArray
from typing import Callable, Protocol, overload
class AsFloatOrArray(Protocol):
    """Callable mapping a float to a float and a float64 array to a float64 array."""

    @overload
    def __call__(self, arg: float) -> float: ...
    @overload
    def __call__(self, arg: NDArray[float64]) -> NDArray[float64]: ...
    def __call__(self, arg): ...


def my_decorator(
    method: Callable[[NDArray[float64]], NDArray[float64]],
) -> AsFloatOrArray:
    """Let a function written for float64 arrays also accept a single scalar.

    The returned wrapper passes NumPy arrays to ``method`` unchanged. Any
    other argument is treated as a scalar: it is placed in a one-element
    float64 array, ``method`` is called on that array, and the first element
    of the result is returned. The overloads give type checkers the
    float -> float and array -> array signatures.
    """
    import functools  # not part of this snippet's import list

    @overload
    def wrapper(arg: float) -> float: ...
    @overload
    def wrapper(arg: NDArray[float64]) -> NDArray[float64]: ...
    def wrapper(arg):
        if isinstance(arg, np.ndarray):
            return method(arg)
        # Dispatch on "is it an ndarray": the ``float`` overload also accepts
        # ints, and NumPy scalars such as np.float32 are not ``float``
        # instances, so an ``isinstance(arg, float)`` test let both reach
        # ``method`` without being wrapped in an array.
        return method(np.array([arg], dtype=float64))[0]

    # Copy __name__, __doc__, etc. from ``method`` and set __wrapped__. Used as
    # a statement (not as @wraps) so ``wrapper`` keeps its overloaded type.
    functools.update_wrapper(wrapper, method)
    return wrapper
@my_decorator
def func(a: NDArray[float64]) -> NDArray[float64]:
    """Return ``a`` multiplied by two."""
    doubled = 2 * a
    return doubled
print(func(1.0))  # prints 2.0 (float in, float out)
print(func(np.array([1.0])))  # prints [2.] (array in, array out)
Or, more concisely, write the decorator as a class and provide overloaded signatures for its `__call__` method:
import numpy as np
from numpy import float64
from numpy.typing import NDArray
from typing import Callable, overload
class MyDecorator:
    """Decorator letting a function written for float64 arrays accept a scalar.

    Calling an instance with a NumPy array passes it to the wrapped function
    unchanged. Any other argument is treated as a scalar: it is placed in a
    one-element float64 array, the wrapped function is called on that array,
    and the first element of the result is returned.
    """

    def __init__(
        self,
        method: Callable[[NDArray[float64]], NDArray[float64]],
    ) -> None:
        import functools  # not part of this snippet's import list

        # Copy __name__, __doc__, etc. from ``method`` onto this instance and
        # set __wrapped__. Done before assigning ``self.method`` so that the
        # copy of ``method.__dict__`` cannot overwrite that attribute.
        functools.update_wrapper(self, method)
        self.method = method

    @overload
    def __call__(self, arg: float) -> float: ...
    @overload
    def __call__(self, arg: NDArray[float64]) -> NDArray[float64]: ...
    def __call__(self, arg):
        if isinstance(arg, np.ndarray):
            return self.method(arg)
        # Dispatch on "is it an ndarray": the ``float`` overload also accepts
        # ints, and NumPy scalars such as np.float32 are not ``float``
        # instances, so an ``isinstance(arg, float)`` test let both reach
        # the wrapped function without being wrapped in an array.
        return self.method(np.array([arg], dtype=float64))[0]
@MyDecorator
def func(a: NDArray[float64]) -> NDArray[float64]:
    """Return ``a`` multiplied by two."""
    doubled = 2 * a
    return doubled
print(func(1.0))  # prints 2.0 (float in, float out)
print(func(np.array([1.0])))  # prints [2.] (array in, array out)