python numba aot

Numba AOT compile functions with functional arguments


I am trying to AOT-compile a Numba function that takes another function as an argument, but I cannot find a way to specify its signature correctly. As a very basic example, with the standard Numba @njit decorator I would write:

import numba as nb

@nb.njit(nb.f8(nb.f8, nb.f8))
def fcn_sum(a, b): 
    return a + b

@nb.njit(nb.f8(nb.typeof(fcn_sum), nb.f8, nb.f8))
def test(fun, a, b): 
    return fun(a, b)

Here nb.typeof(fcn_sum) returns a dispatcher type that is valid only for the fcn_sum function. Unfortunately, the same strategy fails for AOT compilation: it raises a NameError because neither nb nor typeof is recognised:

@cc.export('test', 'f8(nb.typeof(fcn_sum), f8, f8)')
def test(fun, a, b):
    return fun(a, b)

How can I specify the signature of a function argument to make this example work?


Solution

  • Pass the signature as a signature object rather than a string, exactly as in the @njit case; this raises no error:

    @cc.export('test', nb.f8(nb.typeof(fcn_sum), nb.f8, nb.f8))
    def test(fun, a, b):
        return fun(a, b)