python numba

Numba indexing on Record type (structured array in numpy)


I have a NumPy structured array and pass one of its elements to a function, as shown below.

from numba import njit
import numpy as np

dtype = np.dtype([
    ("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4"),
])
a = np.array([(1, b"24q1", 1.0)], dtype=dtype)

@njit
def upsert_numba(a, sid, qtrnm, val):
    a[1] = qtrnm
    a[2] = val
    #i = 0
    #a[i+1] = qtrnm
    #a[i+2] = val
    return a

x = (1, b"24q2", 3.0)
print(upsert_numba(a[0].copy(), *x))

The code above works without problems. However, if the update is done with the commented-out lines instead, i.e. i=0;a[i+1]=qtrnm;a[i+2]=val, Numba raises the following error.

No implementation of function Function(<built-in function setitem>) found for signature:

 >>> setitem(Record(id[type=int32;offset=0],qtrnm0[type=[char x 4];offset=4],qtr0[type=float32;offset=8];12;False), int64, readonly bytes(uint8, 1d, C))

It seems that indexing is only allowed with a constant known at compile time (an integer or a CharSeq), but not with an expression on that constant, even though its value is also known at compile time. What is happening under the hood?

I have also tried another constant as the index, e.g. "j=i; a[j]", which works. Unsurprisingly, "j=i+1; a[j]" fails.


Solution

  • Consider the following functions.

    from numba import njit
    
    
    @njit
    def func():
        i = 777
        t = i
        return t
    
    
    @njit
    def func2():
        i = 776
        t = i + 1
        return t
    

    You can check how each variable's type is inferred using the following method.

    func()
    func.inspect_types()
    

    These are the key lines:

        #   i = const(int, 777)  :: Literal[int](777)
        #   t = i  :: Literal[int](777)
    

    The part after :: is the type of the variable. This indicates that both i and t are of integer literal type.

    Next, for func2:

    func2()
    func2.inspect_types()
    
        #   i = const(int, 776)  :: Literal[int](776)
        #   t = i + $const10.2  :: int64
    

    Compared to func, t is inferred as int64 rather than an integer literal type. In other words, Numba performs type inference before any optimization such as constant folding, so i + 1 is typed as the result of an integer addition and not as the constant 777.

    This is a reasonable choice. Typed code is required for optimization, but type inference is required to generate typed code. So type inference is performed first, on the representation derived from the Python bytecode, and optimization is then performed based on the inferred types. For more accurate and detailed information on this flow, please refer to the official documentation.

    In summary, the index must already be a literal constant at the Python bytecode phase, when type inference runs. An expression such as i + 1 is typed as a plain int64, even though its value could in principle be computed at compile time.


    As an additional note, Numba does not support indexing records with non-literal variables. However, it can be made to work by explicitly defining the mapping via overloading.

    from operator import setitem
    
    import numpy as np
    from numba import njit, types
    from numba.core.extending import overload
    
    a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")])
    
    
    @overload(setitem)
    def setitem_overload_for_a(a, index, value):
        if getattr(a, "dtype", None) != a_dtype:
            return None
    
        if isinstance(value, (types.Integer, types.Float)):
            def numeric_impl(a, index, value):
                # You need to map these indexes correctly according to the dtype.
                if index == 0:
                    a[0] = value
                elif index == 2:
                    a[2] = value
                else:
                    raise ValueError()
    
            return numeric_impl
        elif isinstance(value, (types.Bytes, types.CharSeq)):
            def bytes_impl(a, index, value):
                if index == 1:
                    a[1] = value
                else:
                    raise ValueError()
    
            return bytes_impl
        else:
            raise TypeError(f"Unsupported type: {index=}, {value=}, {a.dtype=}")
    
    
    @njit
    def upsert_numba(a, sid, qtrnm, val):
        i = 0
        a[i + 1] = qtrnm
        a[i + 2] = val
        return a
    
    
    x = (1, b"24q2", 3.0)
    a = np.array([(1, b"24q1", 1.0)], dtype=a_dtype)
    print(upsert_numba(a[0].copy(), *x))  # (1, b'24q2', 3.)
    

    Note that this is an ad hoc strategy: you have to hardcode setitem for each record type, and it may not work in some cases. That said, it should work unless you are doing something very tricky.
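
When hardcoding those branches, the position-to-field mapping can be read off the dtype with plain NumPy rather than counted by hand. A small sketch, where numeric_indexes and bytes_indexes are illustrative names only and the grouping by dtype kind is an assumption that mirrors the numeric/bytes split used in the overload:

```python
import numpy as np

a_dtype = np.dtype([("id", "i4"), ("qtrnm0", "S4"), ("qtr0", "f4")])

numeric_indexes = []  # positions of integer and float fields
bytes_indexes = []    # positions of fixed-width bytes fields

for index, name in enumerate(a_dtype.names):
    # fields[name] is a (dtype, offset) pair; kind is a one-letter type code.
    kind = a_dtype.fields[name][0].kind
    if kind in "iuf":
        numeric_indexes.append(index)
    elif kind == "S":
        bytes_indexes.append(index)

print(numeric_indexes, bytes_indexes)  # [0, 2] [1]
```

The resulting lists correspond to the index values checked in numeric_impl and bytes_impl for this particular dtype.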