
How to define multiple signatures for a function in Numba


I have a function that performs some computation and, at a certain point, calls another function. For example, the main function looks something like this:

import numba

@numba.njit(some signature here)
def my_funct():
    ...
    value = cosd(angle)

Since cosd is called from inside a function decorated with numba.njit, it has to be decorated with numba.njit as well. In my case it is:

import numba
import numpy as np
from numba import float64

@numba.njit(float64(float64))
def cosd(angle):
    # cosine of an angle given in degrees
    return np.cos(np.radians(angle))

My problem now is that in another function the input angle is an array, and the output should be an array as well. I know that I could decorate the function as @numba.njit(float64[:](float64[:])), but then it would no longer accept scalars. How can I tell Numba that the input is something like Union[float64, float64[:]]? The same applies to the output. Thanks a lot!


Solution

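  • I finally found the answer myself. The solution is to pass a list of signatures to the decorator: Numba compiles one version per signature and picks the one matching the argument types at call time. For my example, it would be: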

    import numpy as np
    from numba import njit, float64
    
    @njit([float64(float64), float64[:](float64[:])])
    def cosd(angle):
        return np.cos(np.radians(angle))
    

    I hope this will be helpful to others.