
Numba multi-dimensional indices are not supported


I am running code that simulates an economic model. The code uses NumPy and runs without errors. To speed it up, I added Numba's `njit()` decorator to one specific function, and to my surprise the code now fails. The statement `c[c<2000] = 2000` raises a "Multi-dimensional indices are not supported" error. The weird part is that I perform this same operation in two different places in the function: the first one compiles without errors, but the second one (the second-to-last line of the function, in the `else` branch) triggers the error. The code is the following:

```python
# wage, numba_tile_new, repayment, get_power_utility, get_param_g, get_x1_poli,
# fin_help, tuition and the global r are defined elsewhere (not shown).
@njit()
def get_utility(x1,x1_new,x2,b,b1,e,j,period,param_g):
    if j[1] == 0:  # the individual does not study and no max is needed.
        w = wage(x1_new,x2)*(j[2]/2)   # adjust wages for labor supply decision
        w_vis = np.repeat(w,np.shape(b)[0]*np.shape(e)[0])
        b_vis = numba_tile_new(b,np.shape(w)[0]*np.shape(e)[0])
        e_vis = np.repeat(e,np.shape(w)[0]*np.shape(b)[0])

        c = (w_vis-(1+r)*b_vis+e_vis+repayment(b_vis))
        c[c<2000]  = 2000   # this one compiles without errors
        u = get_power_utility(c)
        # Create indicator for which choice

        pg = get_param_g(j,param_g).astype("float64")

        # Include a constant to g()

        # Create x1 with polynomials.

        x1_poli = get_x1_poli(x1)

        # get g function
        g = x1_poli@pg

        payoff =  u+ 0.1*g

        return payoff[...,np.newaxis]

    else:
        h = fin_help(x1_new,x2,j,period)
        h_vis = np.repeat(h,np.shape(b)[0]*np.shape(e)[0])

        w = wage(x1_new,x2)*(j[2]/2)   # adjust wages for labor supply decision
        w_vis = np.repeat(w,np.shape(b)[0]*np.shape(e)[0])

        b_vis = numba_tile_new(b,np.shape(h)[0]*np.shape(e)[0])
        e_vis = np.repeat(e,np.shape(h)[0]*np.shape(b)[0])
        c = (h_vis-(1+r)*b_vis-tuition(j)+e_vis+ w_vis)
        c =c[...,np.newaxis] + b1

        c[c<2000]  = 2000   # the TypingError is raised here

        return c
```

The error produced is:

```
TypingError: No implementation of function Function() found for signature:

setitem(array(float64, 2d, C), array(bool, 2d, C), Literalint)

There are 16 candidate implementations:
  - Of which 14 did not match due to:
  Overload of function 'setitem': File: : Line N/A.
    With argument(s): '(array(float64, 2d, C), array(bool, 2d, C), int64)':
   No match.
  - Of which 2 did not match due to:
  Overload in function 'SetItemBuffer.generic': File: numba\core\typing\arraydecl.py: Line 219.
    With argument(s): '(array(float64, 2d, C), array(bool, 2d, C), int64)':
   Rejected as the implementation raised a specific error:
     NumbaTypeError: Multi-dimensional indices are not supported.
  raised from C:\Users\Sergi\anaconda3\Lib\site-packages\numba\core\typing\arraydecl.py:89
```

As I mentioned, this is extremely weird, since if I comment out the second-to-last line, the code runs without errors even though I still use `c[c<2000] = 2000` elsewhere.

I am not including all the code needed to replicate the error since it is a very long file, but I am quite sure the error comes from this part.

Thanks in advance!


Solution

  • Numba and NumPy are two different beasts. When you decorate a function with `njit()`, Numba compiles it "just in time" to machine code in nopython mode, and inside that compiled function you can only use the parts of NumPy that Numba reimplements. Anything outside that subset, including some forms of indexing that plain NumPy accepts, fails at compile time with a `TypingError` like the one you see.

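    So what's the problem? Numba (at least the version producing this error) supports only a subset of NumPy's advanced indexing: a single advanced index, and it has to be a one-dimensional array. `c[c < 2000] = 2000` uses the boolean array `c < 2000` as that index, so it only compiles when `c` is 1-D. That is why your two occurrences behave differently. In the `if` branch, `c` is 1-D (it is combined from flat vectors such as the output of `np.repeat`, which without an `axis` always returns a 1-D array), so the mask is 1-D and the assignment compiles. In the `else` branch, `c = c[...,np.newaxis] + b1` broadcasts `c` to two dimensions, so the mask becomes `array(bool, 2d, C)`, exactly the type named in the error message, and Numba rejects it. You can still carry out the operation, but you must first flatten the array (NOTE: using ravel does not work here). For example, in your code you could write: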

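    orig_shape = c.shape
    flat = c.flatten()            # 1-D copy of c
    flat[flat < 2000] = 2000      # a single 1-D boolean index is supported
    c = flat.reshape(orig_shape)  # restore the original (2-D) shape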
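A minimal, self-contained sketch comparing the flatten/reshape workaround with two alternatives that avoid the 2-D boolean mask altogether: `np.maximum`, an elementwise ufunc that Numba supports in nopython mode, and an explicit loop, which is a common style inside Numba code. The function names `clip_flatten`, `clip_maximum`, `clip_loop` and the `floor` parameter are hypothetical, `clip_loop` assumes a 2-D input, and the `try`/`except` fallback exists only so the logic can be run where Numba is not installed.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Hypothetical fallback only so this sketch runs without Numba;
    # with Numba installed, the functions below are compiled in nopython mode.
    def njit(func):
        return func


@njit
def clip_flatten(c, floor):
    # Workaround from the answer: apply a 1-D mask to a flattened copy,
    # then restore the original shape.
    orig_shape = c.shape
    flat = c.flatten()
    flat[flat < floor] = floor
    return flat.reshape(orig_shape)


@njit
def clip_maximum(c, floor):
    # No boolean mask at all: np.maximum compares elementwise and
    # broadcasts the scalar floor against the whole array.
    return np.maximum(c, floor)


@njit
def clip_loop(c, floor):
    # Explicit loops over a 2-D array; works on a copy so c is unchanged.
    out = c.copy()
    for i in range(out.shape[0]):
        for k in range(out.shape[1]):
            if out[i, k] < floor:
                out[i, k] = floor
    return out
```

All three return a new array and leave `c` untouched. In the question's `else` branch, the failing line could, for instance, become `c = np.maximum(c, 2000.0)`, which skips both the mask and the flatten/reshape round trip; which variant is fastest for a given model is worth measuring rather than assuming.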