I find this behavior quite counter-intuitive, although I suppose there is a reason for it: Numba automatically converts my NumPy integer return type directly into a Python `int`.
import numba as nb
import numpy as np
print(f"Numba version: {nb.__version__}") # 0.59.0
print(f"NumPy version: {np.__version__}") # 1.23.5
# Explicitly define the signature
sig = nb.uint32(nb.uint32, nb.uint32)


@nb.njit(sig, cache=False)
def test_fn(a, b):
    """Return the product ``a * b``.

    Compiled with the explicit signature ``uint32(uint32, uint32)``.
    The trailing comments on the demo lines below record what the
    author observed: the value comes back to Python as a plain ``int``.
    """
    return a * b


# Demonstration: call from Python and inspect the type of the result.
res = test_fn(2, 10)
print(f"Result value: {res}")  # returns 20
print(f"Result type: {type(res)}")  # returns <class 'int'>
This is an issue because I'm using the return value as an input to another njit function, so I get a casting warning (and I also end up doing unnecessary casts in between the njit functions).
Is there any way to force Numba to give me `np.uint32` as a result instead?
--- EDIT ---
This is the best I've managed to do myself; however, I refuse to believe this is the best implementation out there:
# we manually define a return record and pass it as a parameter
res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))


@nb.njit(sig, cache=False)
def test_fn(a: np.uint32, b: np.uint32, res: res_type):
    """Store ``a * b`` in the ``'res'`` field of the record ``res``.

    The signature is declared ``void``: the product is delivered only
    through the caller-supplied record, never through a return value.
    """
    res['res'] = a * b


# Call with Python ints (Numba should coerce based on signature)
res = np.recarray(1, dtype=res_type)[0]
# The signature returns void, so ``res_py_in`` does not hold the product;
# the result is read back from ``res['res']`` below.
res_py_in = test_fn(2, 10, res)
print("\nCalled with Python ints:")
print(f"Result value: {res['res']}")  # 20
print(f"Result type: {type(res['res'])}")  # <class 'numpy.uint32'>
--- EDIT 2 --- As @Nin17 correctly pointed out, actually returning an `int` object is still about 3 times quicker when called from a Python context, so it's better to just return a simple `int` and cast as needed.
Why don't you just return `np.uint32(a*b)`:
@nb.njit(nb.uint32(nb.uint32, nb.uint32))
def func(a, b):
    """Return ``a * b`` explicitly cast to ``np.uint32``."""
    return np.uint32(a * b)
It is faster and more readable than the other solutions:
import numba as nb
import numpy as np
@nb.njit(nb.types.Array(nb.uint32, 0, "C")(nb.uint32, nb.uint32))
def test_fn(a, b):
    """Return ``a * b`` wrapped in a 0-dimensional ``uint32`` array."""
    # Allocate a 0-d array and write the product into its single element.
    res = np.empty((), dtype=np.uint32)
    res[...] = a * b
    return res
# Record layout with a single uint32 field, used as an out-parameter.
res_type = np.dtype([('res', np.uint32)])
sig = nb.void(nb.uint32, nb.uint32, nb.from_dtype(res_type))


@nb.njit(sig)
def test_fn2(a, b, out):
    """Store ``a * b`` in the ``'res'`` field of the record ``out``.

    Declared ``void``: nothing is returned, the caller reads ``out['res']``.
    """
    out['res'] = a * b
# Record scalar that test_fn2 writes its result into.
res = np.recarray(1, dtype=res_type)[0]
# One call of test_fn2 ahead of the timings below.
test_fn2(np.uint32(2), np.uint32(10), res)
# Inputs already typed as np.uint32, shared by all three timed variants.
a = np.uint32(2)
b = np.uint32(10)
# IPython %timeit magics: time one call of each variant made from Python.
%timeit test_fn(a, b)
%timeit test_fn2(a, b, res)
%timeit func(a, b)
Output:
339 ns ± 4.67 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
426 ns ± 1.01 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
126 ns ± 0.111 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)
# Number of inner calls made by each benchmark wrapper.
N = int(1e7)


@nb.njit
def _test_fn(a, b):
    """Call ``test_fn(a, b)`` ``N`` times from compiled code.

    Returns a ``uint32`` array of length ``N`` holding every result.
    """
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        # test_fn returns a 0-d array; .item() extracts its scalar value.
        out[i] = test_fn(a, b).item()
    return out
@nb.njit
def _test_fn2(a, b, res):
    """Call ``test_fn2(a, b, res)`` ``N`` times from compiled code.

    ``res`` is the record that ``test_fn2`` writes into; after each call
    its ``'res'`` field is copied into the output array.  Returns a
    ``uint32`` array of length ``N``.
    """
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        test_fn2(a, b, res)
        out[i] = res['res']
    return out
@nb.njit
def _func(a, b):
    """Call ``func(a, b)`` ``N`` times from compiled code.

    Returns a ``uint32`` array of length ``N`` holding every result.
    """
    out = np.empty((N,), dtype=np.uint32)
    for i in range(N):
        out[i] = func(a, b)
    return out
# Call each wrapper once before timing -- presumably so that compilation of
# the signature-less @nb.njit wrappers is not included in the measurements.
_test_fn(a, b)
_test_fn2(a, b, res)
_func(a, b)
# IPython %timeit magics: time N calls of each variant made from compiled code.
%timeit _test_fn(a, b)
%timeit _test_fn2(a, b, res)
%timeit _func(a, b)
Output:
254 ms ± 508 μs per loop (mean ± std. dev. of 7 runs, 1 loop each)
3.44 ms ± 40.4 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)
3.37 ms ± 19.9 μs per loop (mean ± std. dev. of 7 runs, 100 loops each)