I would like to dispatch on the type of the second argument of a Numba function, but I fail to do so.
If it is an integer, a vector should be returned; if it is an array of integers, a matrix should be returned.
The first version does not work:
```python
import numba as nb
import numpy as np
from numba import njit

@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    elif isinstance(indices, nb.int64[:]):
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos
```
while the second one, with an `else`, does:
```python
@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    else:
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos
```
I guess the problem is the type declaration via `nb.int64[:]`, but I cannot get it to work any other way.
Do you have an idea?
Note that this question applies to numba>=0.59.
`generated_jit` was deprecated in earlier versions and has been removed as of version 0.59.
You should not use `isinstance` in a JIT function like this. Instead, use `@overload` (`@generated_jit` was the old, now obsolete way to do it), which is specifically made for this purpose. It also lets Numba generate the code faster, since only the relevant part of the function is compiled for each case rather than all the cases for each specialization. Moreover, `isinstance` is experimental, as Numba states in a warning when your first code is executed (warnings are reported so that users read them ;) ).
`@overload` method

Starting from Numba 0.59, `overload` must be used instead:
```python
import numba as nb
import numba.extending  # ensure the numba.extending submodule is loaded
import numpy as np

def test_dispatch_scalar(X, indices):
    ref_pos = np.empty(3, np.float64)
    ref_pos[:] = X[:, indices]
    return ref_pos

def test_dispatch_vector(X, indices):
    ref_pos = np.empty((3, len(indices)), np.float64)
    ref_pos[:, :] = X[:, indices]
    return ref_pos

# Pure-Python fallback implementation
def test_dispatch_impl(X, indices):
    if isinstance(indices, (int, np.integer)):
        return test_dispatch_scalar(X, indices)
    elif isinstance(indices, np.ndarray) and indices.ndim == 1 and np.issubdtype(indices.dtype, np.integer):
        return test_dispatch_vector(X, indices)
    else:
        assert False  # Unsupported

# Numba-specific overload: it is called at compile time with the Numba *types*
# of the arguments and returns the implementation to compile for those types
@nb.extending.overload(test_dispatch_impl)
def test_dispatch_impl_overload(X, indices):
    if isinstance(indices, nb.types.Integer):
        return test_dispatch_scalar
    elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
        return test_dispatch_vector
    else:
        assert False  # Unsupported

@nb.njit
def test_dispatch(X, indices):
    return test_dispatch_impl(X, indices)
```
With numba<0.59, the same dispatch can also be written with `generated_jit`. Here is an example reasoning about generic types:
```python
import numba as nb
import numpy as np

@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
    # indices is a Numba type here, not a value
    if isinstance(indices, nb.types.Integer):
        def test_dispatch_scalar(X, indices):
            ref_pos = np.empty(3, np.float64)
            ref_pos[:] = X[:, indices]
            return ref_pos
        return test_dispatch_scalar
    elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
        def test_dispatch_vector(X, indices):
            ref_pos = np.empty((3, len(indices)), np.float64)
            ref_pos[:, :] = X[:, indices]
            return ref_pos
        return test_dispatch_vector
    else:
        assert False  # Unsupported
```
Here is an example reasoning about specific types:
```python
import numba as nb
import numpy as np

@nb.generated_jit(nopython=True)
def test_dispatch(X, indices):
    # indices is a Numba type here, compared against exact Numba types
    if indices == nb.types.int64:
        def test_dispatch_scalar(X, indices):
            ref_pos = np.empty(3, np.float64)
            ref_pos[:] = X[:, indices]
            return ref_pos
        return test_dispatch_scalar
    elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and indices.dtype == nb.types.int64:
        def test_dispatch_vector(X, indices):
            ref_pos = np.empty((3, len(indices)), np.float64)
            ref_pos[:, :] = X[:, indices]
            return ref_pos
        return test_dispatch_vector
    else:
        assert False  # Unsupported
```
Requiring exactly 64-bit integers can be too restrictive, so I advise you to mix generic type tests with specific ones. For the same reason, you should avoid testing whether an array is of one exact type: arrays can be contiguous or not, and their item type can differ while still being compatible with your function.
Note that generated JIT functions are meant to generate functions that are compiled separately for each input type (not for each input value).