I have a numpy structured array and pass one element of it to a function, as shown below.
from numba import njit
import numpy as np

dtype = np.dtype([
    ("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4"),
])
a = np.array([(1, b"24q1", 1.0)], dtype=dtype)

@njit
def upsert_numba(a, sid, qtrnm, val):
    a[1] = qtrnm
    a[2] = val
    #i = 0
    #a[i+1] = qtrnm
    #a[i+2] = val
    return a

x = (1, b"24q2", 3.0)
print(upsert_numba(a[0].copy(), *x))
The above code works without problems. However, if the update is done through the commented-out lines instead, i.e. i=0;a[i+1]=qtrnm;a[i+2]=val, numba raises the following error.
No implementation of function Function(<built-in function setitem>) found for signature:
>>> setitem(Record(id[type=int32;offset=0],qtrnm0[type=[char x 4];offset=4],qtr0[type=float32;offset=8];12;False), int64, readonly bytes(uint8, 1d, C))
It seems that indexing is only allowed with a constant known at compile time, which can be an integer or a CharSeq, but not with an expression on that constant, even though its value is also known at compile time. What is happening under the hood?
I have also tried other constant indices such as "j=i; a[j]", which works. Unsurprisingly, "j=i+1; a[j]" fails.
Consider the following functions.
from numba import njit

@njit
def func():
    i = 777
    t = i
    return t

@njit
def func2():
    i = 776
    t = i + 1
    return t
You can check how each variable's type is inferred using the following method.
func()
func.inspect_types()
These are the key lines:
# i = const(int, 777) :: Literal[int](777)
# t = i :: Literal[int](777)
The part after :: is the type of the variable.
This indicates that both i and t are of integer literal type.
Next, for func2:
func2()
func2.inspect_types()
# i = const(int, 776) :: Literal[int](776)
# t = i + $const10.2 :: int64
Compared to func, you can see that t is inferred as int64 rather than an integer literal type.
This means that numba performs type inference on the code before optimization: at that stage, i + 1 has not yet been folded into a constant.
This is a reasonable choice. Typed code is required for optimization, but type inference is required to generate typed code. So first type inference is performed on Python bytecode, and then optimization is performed based on the inferred types. For more accurate and detailed information on this flow, please refer to the official documentation.
In summary, the index must already be a constant at the Python bytecode stage.
As an additional note, numba does not support indexing records with non-literal variables. However, it is possible to some extent by explicitly defining the mapping via overloading.
from operator import setitem

import numpy as np
from numba import njit, types
from numba.core.extending import overload

a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")])

@overload(setitem)
def setitem_overload_for_a(a, index, value):
    if getattr(a, "dtype", None) != a_dtype:
        return None

    if isinstance(value, (types.Integer, types.Float)):
        def numeric_impl(a, index, value):
            # You need to map these indexes correctly according to the dtype.
            if index == 0:
                a[0] = value
            elif index == 2:
                a[2] = value
            else:
                raise ValueError()

        return numeric_impl

    elif isinstance(value, (types.Bytes, types.CharSeq)):
        def bytes_impl(a, index, value):
            if index == 1:
                a[1] = value
            else:
                raise ValueError()

        return bytes_impl
    else:
        raise TypeError(f"Unsupported type: {index=}, {value=}, {a.dtype=}")

@njit
def upsert_numba(a, sid, qtrnm, val):
    i = 0
    a[i + 1] = qtrnm
    a[i + 2] = val
    return a

x = (1, b"24q2", 3.0)
a = np.array([(1, b"24q1", 1.0)], dtype=a_dtype)
print(upsert_numba(a[0].copy(), *x))  # (1, b'24q2', 3.)
Note that this is an ad hoc strategy: it requires you to hardcode setitem for each record type, and it may not work in some cases. That said, it should work unless you're doing something very tricky.