I have a function that does some computation and, at a certain point, calls another function. For example, the main function looks something like this:
```python
import numba

@numba.njit(some signature here)
def my_funct():
    ...
    value = cosd(angle)
```
Since the function `cosd` is called inside another function decorated with `numba.njit`, it has to be decorated as well. In my case it is:
```python
import numba
import numpy as np
from numba import float64

@numba.njit(float64(float64))
def cosd(angle):
    # Cosine of an angle given in degrees
    return np.cos(np.radians(angle))
```
My problem now is that in another function the input `angle` is an array, and the corresponding output is an array as well. I know that I could decorate my function with `@numba.njit(float64[:](float64[:]))`, but then the function would no longer accept scalars. How can I tell Numba that the input is something like `Union[float64, float64[:]]`? The same applies to the output. Thanks a lot!
I finally found the answer myself: pass a list of signatures to the decorator. For my example, it would be:
```python
import numpy as np
from numba import float64, njit

# One compiled specialization per signature in the list
@njit([float64(float64), float64[:](float64[:])])
def cosd(angle):
    return np.cos(np.radians(angle))
```
I hope this will be helpful to others.