Tags: python, numba

How to force Numba to return a NumPy type?


I find this behavior quite counter-intuitive, although I suppose there is a reason for it: Numba automatically converts my NumPy integer types directly into a Python int:

import numba as nb
import numpy as np 

print(f"Numba version: {nb.__version__}")  # 0.59.0
print(f"NumPy version: {np.__version__}")  # 1.23.5

# Explicit signature: two uint32 arguments and a uint32 result.
# (nb.types.uint32 is the same type object as the nb.uint32 shortcut.)
sig = nb.types.uint32(nb.types.uint32, nb.types.uint32)


# nb.jit(..., nopython=True) is exactly what nb.njit expands to.
@nb.jit(sig, nopython=True, cache=False)
def test_fn(a, b):
    """Return the product of ``a`` and ``b`` under the uint32 signature."""
    return a * b


# Called from Python, the declared uint32 result comes back as a plain int:
res = test_fn(2, 10)
print(f"Result value: {res}")  # returns 20
print(f"Result type: {type(res)}")  # returns <class 'int'>

This is an issue because I'm using the return value as an input to another njit function, so I get a casting warning (and I also do unnecessary casts between the njit functions).

Is there any way to force Numba to give me an np.uint32 as the result instead?

--- EDIT ---

This is the best I've managed to do myself; however, I refuse to believe this is the best implementation out there:

# We manually define a one-field return record and pass it in as a parameter.
res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))


@nb.njit(sig, cache=False)
def test_fn(a: np.uint32, b: np.uint32, res: res_type):
    """Write ``a * b`` into the ``'res'`` field of the record ``res``.

    The signature's return type is ``void``, so nothing is returned; the
    result is delivered through the record argument instead.
    """
    res['res'] = a * b


# Call with Python ints (Numba should coerce based on signature).
# test_fn returns None, so there is no return value worth keeping;
# the result is read back from the record afterwards.
res = np.recarray(1, dtype=res_type)[0]
test_fn(2, 10, res)
print("\nCalled with Python ints:")
print(f"Result value: {res['res']}")  # 20
print(f"Result type: {type(res['res'])}")  # <class 'numpy.uint32'>

--- EDIT 2 --- As @Nin17 correctly pointed out, returning an int object is actually still about 3 times quicker when called from a Python context, so it's better to just return a simple int and cast as needed.


Solution

  • Why don't you just return np.uint32(a * b)?

    @nb.njit("uint32(uint32, uint32)")
    def func(a, b):
        """Multiply ``a`` and ``b`` and return the product as a uint32."""
        product = a * b
        # The explicit cast keeps the jitted result typed as uint32,
        # matching the declared signature.
        return np.uint32(product)
    

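    It is faster (or, inside other jitted code, about as fast) and more readable than the other solutions. Note that when it is called from Python, the uint32 result is still returned as a plain Python int. The gain is that other njit functions calling it receive a uint32 and need no casts: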

    import numba as nb
    import numpy as np
    
    @nb.njit(nb.types.Array(dtype=nb.uint32, ndim=0, layout="C")(nb.uint32, nb.uint32))
    def test_fn(a, b):
        """Return ``a * b`` stored in a freshly allocated 0-d uint32 array."""
        cell = np.empty((), dtype=np.uint32)
        cell[...] = a * b
        return cell
    
    # Record dtype with a single uint32 field, used as an output parameter.
    res_type = np.dtype([('res', np.uint32)])
    sig = nb.types.void(nb.types.uint32, nb.types.uint32, nb.from_dtype(res_type))

    @nb.njit(sig)
    def test_fn2(a, b, out):
        """Store ``a * b`` in the ``'res'`` field of the record ``out``."""
        out['res'] = a * b
    
    # One-element record array; element [0] is the record test_fn2 writes into.
    res = np.recarray(1, dtype=res_type)[0]
    test_fn2(np.uint32(2), np.uint32(10), res)
    
    # Benchmark inputs, passed as NumPy uint32 scalars.
    a = np.uint32(2)
    b = np.uint32(10)
    
    
    # IPython %timeit magics (IPython/Jupyter only, not plain Python):
    # per-call cost of each approach when invoked from the interpreter.
    %timeit test_fn(a, b)
    %timeit test_fn2(a, b, res)
    %timeit func(a, b)
    

    Output:

    339 ns ± 4.67 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
    426 ns ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
    126 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
    
    N = 10_000_000  # calls per benchmark run; module-level global read by the jitted loops

    @nb.njit
    def _test_fn(a, b):
        """Call the 0-d-array variant N times from jitted code and collect the results."""
        results = np.empty(N, dtype=np.uint32)
        for k in range(results.shape[0]):
            results[k] = test_fn(a, b).item()
        return results
    
    @nb.njit
    def _test_fn2(a, b, res):
        """Call the record-output variant N times, copying each result out of ``res``."""
        results = np.empty(N, dtype=np.uint32)
        for k in range(results.shape[0]):
            test_fn2(a, b, res)
            results[k] = res['res']
        return results
    
    @nb.njit
    def _func(a, b):
        """Call ``func`` N times from jitted code and collect the uint32 results."""
        results = np.empty(N, dtype=np.uint32)
        for k in range(results.shape[0]):
            results[k] = func(a, b)
        return results
    
    # Call each driver once before timing: the @nb.njit drivers have no
    # explicit signature, so this first call triggers their compilation.
    _test_fn(a, b)
    _test_fn2(a, b, res)
    _func(a, b)
    
    # IPython %timeit magics: time N calls made from inside jitted code.
    %timeit _test_fn(a, b)
    %timeit _test_fn2(a, b, res)
    %timeit _func(a, b)
    

    Output:

    254 ms ± 508 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
    3.44 ms ± 40.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
    3.37 ms ± 19.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)