I have noticed that when I pass a 2D array of 0s and 1s into a Numba njit function, reshape it, and then cast it to np.int32 or numba.int32, the cast array prints with different values from the reshaped array it was cast from.
Here is example code:
import numpy as np
import numba
from numba import njit

array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T
num_cols = array_2d.shape[1]
num_rows = array_2d.shape[0]

@njit
def f(array, num_rows, num_cols):
    # Group consecutive rows into pairs: shape (num_rows // 2, 2, num_cols)
    pairs = array.reshape(num_rows // 2, 2, num_cols)
    pairs_cast = pairs.astype(numba.int32)
    return pairs, pairs_cast

pairs, pairs_cast = f(array_2d, num_rows, num_cols)
print("Pairs:")
print(pairs)
print("\nPairs cast to int32:")
print(pairs_cast)
The output is:
Pairs:
[[[0 0 0]
  [1 1 1]]

 [[0 1 1]
  [1 0 0]]

 [[1 1 1]
  [0 1 1]]

 [[0 0 0]
  [1 0 0]]]

Pairs cast to int32:
[[[0 0 0]
  [1 1 1]]

 [[1 1 1]
  [0 1 1]]

 [[0 1 1]
  [0 0 0]]

 [[1 0 0]
  [1 0 0]]]
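Both printed results can be reproduced with NumPy alone, which narrows down what is happening. array_2d is the transpose of a C-ordered array, so it is Fortran-contiguous in memory. The uncast pairs equal NumPy's ordinary (C-order) reshape, while pairs_cast equals a Fortran-order reshape, i.e. the underlying memory buffer read as a Fortran-ordered (4, 2, 3) array. One plausible reading, offered as an assumption rather than something confirmed from Numba's source, is that Numba labels the reshaped view as Fortran-contiguous even though its strides describe a C-order reshape: printing follows the strides and shows the expected values, whereas astype trusts the layout label and walks the buffer in memory order.

```python
import numpy as np

# Same input as in the question: a transpose, hence Fortran-ordered in memory
array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T
num_rows, num_cols = array_2d.shape

# Contiguity flags of the input
print(array_2d.flags["C_CONTIGUOUS"], array_2d.flags["F_CONTIGUOUS"])  # prints: False True

# Matches the "Pairs" output: an ordinary C-order reshape (a strided view)
c_pairs = array_2d.reshape(num_rows // 2, 2, num_cols)

# Matches the "Pairs cast to int32" output: the same data reshaped in Fortran order
f_pairs = array_2d.reshape(num_rows // 2, 2, num_cols, order="F").astype(np.int32)

print(c_pairs)
print(f_pairs)
```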
I'd be curious to know what's going on here.
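As mentioned in the comments, reshape in Numba currently supports only contiguous arrays. No error is raised here because array_2d is the transpose of a C-contiguous array and is therefore Fortran-contiguous, which Numba accepts as contiguous. You can see the error if you instead produce an array whose layout Numba cannot determine at compile time, for example by transposing inside the compiled code: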
import numpy as np
import numba as nb

@nb.njit
def func(a):
    return a.transpose(1, 0).reshape(-1)

func(np.ones((10, 10)))
raises:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
reshape() supports contiguous array only
- Resolution failure for non-literal arguments:
reshape() supports contiguous array only
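Assuming the mismatch only arises for input that is not C-ordered, a simple workaround is to give f a C-contiguous copy, e.g. f(np.ascontiguousarray(array_2d), num_rows, num_cols). np.ascontiguousarray is also supported in nopython mode, but converting before the call leaves the jitted function unchanged. The NumPy-only sketch below models why this helps. buffer_walk_cast is a hypothetical helper, not part of Numba: it reads an array in raw memory order and lays the values out in the order its layout flag suggests, roughly what a cast that trusts a mislabelled layout would do. For the Fortran-ordered input it disagrees with the intended result; for a C-ordered copy, memory order and logical order coincide, so it cannot.

```python
import numpy as np


def buffer_walk_cast(a, shape, dtype=np.int32):
    # Hypothetical model (not a Numba API): read `a` in raw memory order, then
    # fill `shape` in the order implied by its layout flag instead of using
    # the strides of a proper C-order reshape.
    order = "F" if a.flags["F_CONTIGUOUS"] and not a.flags["C_CONTIGUOUS"] else "C"
    return a.ravel(order="K").reshape(shape, order=order).astype(dtype)


array_2d = np.array([[0, 1, 0, 1, 1, 0, 0, 1],
                     [0, 1, 1, 0, 1, 1, 0, 0],
                     [0, 1, 1, 0, 1, 1, 0, 0]]).T
shape = (array_2d.shape[0] // 2, 2, array_2d.shape[1])

# What reshape followed by astype is supposed to return
intended = array_2d.reshape(shape).astype(np.int32)

# Fortran-ordered input: the model disagrees with the intended result
print(np.array_equal(buffer_walk_cast(array_2d, shape), intended))  # False

# C-ordered copy with the same values: memory order equals logical order
array_c = np.ascontiguousarray(array_2d)
print(np.array_equal(buffer_walk_cast(array_c, shape), intended))  # True
```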