
Strange interaction between reshaping and type casting in Numba


I have noticed that when I pass a 2D array of 0s and 1s into a Numba njit function, reshape it, and then cast it to np.int32 or numba.int32, the printed cast array differs from the reshaped array it was made from.

Here is example code:

import numpy as np
import numba  # needed because f() refers to numba.int32
from numba import njit

array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T

num_cols = array_2d.shape[1]
num_rows = array_2d.shape[0]

@njit
def f(array, num_rows, num_cols):
    pairs = array.reshape(num_rows // 2, 2, num_cols)

    pairs_cast = pairs.astype(numba.int32)

    return pairs, pairs_cast

pairs, pairs_cast = f(array_2d, num_rows, num_cols)

print("Pairs:")
print(pairs)
print("\nPairs cast to int32:")
print(pairs_cast)

The output is:

Pairs:
[[[0 0 0]
  [1 1 1]]

 [[0 1 1]
  [1 0 0]]

 [[1 1 1]
  [0 1 1]]

 [[0 0 0]
  [1 0 0]]]

Pairs cast to int32:
[[[0 0 0]
  [1 1 1]]

 [[1 1 1]
  [0 1 1]]

 [[0 1 1]
  [0 0 0]]

 [[1 0 0]
  [1 0 0]]]
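A NumPy-only check of the input (no Numba needed). Because array_2d is built with .T, it is Fortran-contiguous (column-major) rather than C-contiguous, and the two printed results coincide with a row-major (C-order) and a column-major (F-order) reshape of the same data. This suggests, though it does not prove, that somewhere in the reshape/astype path the Fortran layout of the input is handled inconsistently. The names c_pairs, f_pairs, pairs_out and cast_out are illustrative only; pairs_out and cast_out are the outputs printed above, copied verbatim.

```python
import numpy as np

# Same input as in the question: .T turns a C-ordered (3, 8) array into an F-ordered (8, 3) view
array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T
num_rows, num_cols = array_2d.shape  # (8, 3)

print(array_2d.flags["C_CONTIGUOUS"], array_2d.flags["F_CONTIGUOUS"])  # False True

# Row-major (C-order) reshape: element order follows the logical rows
c_pairs = array_2d.reshape(num_rows // 2, 2, num_cols)
# Column-major (F-order) reshape: element order follows the columns
f_pairs = array_2d.reshape(num_rows // 2, 2, num_cols, order="F")

# Outputs printed in the question, copied verbatim
pairs_out = np.array([[[0, 0, 0], [1, 1, 1]],
                      [[0, 1, 1], [1, 0, 0]],
                      [[1, 1, 1], [0, 1, 1]],
                      [[0, 0, 0], [1, 0, 0]]])
cast_out = np.array([[[0, 0, 0], [1, 1, 1]],
                     [[1, 1, 1], [0, 1, 1]],
                     [[0, 1, 1], [0, 0, 0]],
                     [[1, 0, 0], [1, 0, 0]]])

print(np.array_equal(c_pairs, pairs_out))  # True: "Pairs" is the C-order reshape
print(np.array_equal(f_pairs, cast_out))   # True: "Pairs cast" is the F-order reshape
```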

I'd be curious to know what's going on here.


Solution

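  • As mentioned in the comments, reshape in Numba currently supports only contiguous arrays, yet no error is raised here. One plausible explanation is that array_2d is created with .T, so it is Fortran-contiguous (column-major) rather than C-contiguous; Numba appears to accept both C- and F-contiguous arrays in reshape, so the check passes even though the results above are inconsistent. You can see the error if you make the array non-contiguous inside your Numba-compiled code: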

    import numpy as np
    import numba as nb
    
    @nb.njit
    def func(a):
        # transpose(1, 0) yields a non-contiguous view, so reshape() is rejected at compile time
        return a.transpose(1, 0).reshape(-1)
    
    func(np.ones((10, 10)))
    

    raises:

    TypingError: Failed in nopython mode pipeline (step: nopython frontend)
    - Resolution failure for literal arguments:
    reshape() supports contiguous array only
    - Resolution failure for non-literal arguments:
    reshape() supports contiguous array only
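A possible workaround, sketched under the assumption that the mismatch is tied to the Fortran-ordered input: copy the array into C-order memory with np.ascontiguousarray before passing it to the jitted function. This is not a confirmed fix from the thread. The try/except fallback for njit is only there so the sketch also runs where Numba is not installed (it then runs as plain NumPy); casting with np.int32 instead of numba.int32 avoids needing import numba.

```python
import numpy as np

try:
    from numba import njit
except ImportError:  # fallback so the sketch runs without Numba (no JIT, plain NumPy)
    def njit(fn):
        return fn


@njit
def f(array, num_rows, num_cols):
    # Split the rows into pairs: shape (num_rows // 2, 2, num_cols)
    pairs = array.reshape(num_rows // 2, 2, num_cols)
    pairs_cast = pairs.astype(np.int32)
    return pairs, pairs_cast


array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T
num_rows, num_cols = array_2d.shape

# Copy the F-ordered view into row-major (C-order) memory before calling f()
array_c = np.ascontiguousarray(array_2d)

pairs, pairs_cast = f(array_c, num_rows, num_cols)
print(np.array_equal(pairs, pairs_cast))  # True: the cast now matches the reshape
```

With a C-contiguous input, reshape and astype agree on element order, so the cast array prints the same values as the reshaped one.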