
numba dispatch on type


I would like to dispatch on the type of the second argument of a Numba function, but I have not managed to do so.

If it is an integer, a vector should be returned; if it is an array of integers, a matrix should be returned.
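In plain NumPy, the same expression X[:, indices] already produces the two desired shapes, so the only difficulty is getting Numba to compile the matching branch. A minimal illustration, assuming (as in the code below) that X has three rows:

```python
import numpy as np

# Assumption: X has 3 rows, matching the np.empty(3, ...) buffers in the question
X = np.arange(12, dtype=np.float64).reshape(3, 4)

col = X[:, 1]                   # integer index -> vector of shape (3,)
cols = X[:, np.array([0, 2])]   # 1D integer-array index -> matrix of shape (3, 2)
print(col.shape, cols.shape)
```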

The first version does not work:

import numba as nb
import numpy as np
from numba import njit

@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    elif isinstance(indices, nb.int64[:]):
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos

The second version, which uses an else instead of the elif, does work:

@njit
def test_dispatch(X, indices):
    if isinstance(indices, nb.int64):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    else:
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos

I suspect the problem is the type declaration via nb.int64[:], but I cannot get it to work any other way. Do you have an idea?

Note that this question applies to numba>=0.59: generated_jit was deprecated in earlier versions and has been removed as of version 0.59.


Solution

  • You should not use isinstance in a JIT function like this. Instead, use @overload (@generated_jit was the old, now obsolete, way to do this), which is made specifically for this purpose. It also lets Numba generate code faster, since only the relevant part of the function is compiled for each case rather than all the cases for each specialization. Moreover, isinstance support is experimental, as Numba states in a warning when your first code is executed (warnings are reported for users to read them ;) ).


    Using the new @overload method

    Starting with Numba 0.59, @overload must be used instead:

    import numba as nb
    import numpy as np
    
    def test_dispatch_scalar(X, indices):
        ref_pos = np.empty(3, np.float64)
        ref_pos[:] = X[:, indices]
        return ref_pos
    
    def test_dispatch_vector(X, indices):
        ref_pos = np.empty((3, len(indices)), np.float64)
        ref_pos[:, :] = X[:, indices]
        return ref_pos
    
    # Pure-python fallback implementation
    def test_dispatch_impl(X, indices):
        if isinstance(indices, (int, np.integer)):
            return test_dispatch_scalar(X, indices)
        elif isinstance(indices, np.ndarray) and indices.ndim == 1 and np.issubdtype(indices.dtype, np.integer):
            return test_dispatch_vector(X, indices)
        else:
            raise TypeError("Unsupported type for indices")
    
    # Numba-specific overload
    @nb.extending.overload(test_dispatch_impl)
    def test_dispatch_impl_overload(X, indices):
        if isinstance(indices, nb.types.Integer):
            return test_dispatch_scalar
        elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
            return test_dispatch_vector
        else:
            return None  # No match: Numba then reports a typing error for unsupported types
    
    @nb.njit
    def test_dispatch(X, indices):
        return test_dispatch_impl(X, indices)
    

    Old deprecated solution (generated_jit, removed in Numba 0.59)

    Here is an example reasoning about generic types:

    import numba as nb
    import numpy as np
    
    @nb.generated_jit(nopython=True)
    def test_dispatch(X, indices):
        if isinstance(indices, nb.types.Integer):
            def test_dispatch_scalar(X, indices):
                ref_pos = np.empty(3, np.float64)
                ref_pos[:] = X[:, indices]
                return ref_pos
            return test_dispatch_scalar
        elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and isinstance(indices.dtype, nb.types.Integer):
            def test_dispatch_vector(X, indices):
                ref_pos = np.empty((3, len(indices)), np.float64)
                ref_pos[:, :] = X[:, indices]
                return ref_pos
            return test_dispatch_vector
        else:
            assert False # Unsupported
    

    Here is an example reasoning about specific types:

    import numba as nb
    import numpy as np
    
    @nb.generated_jit(nopython=True)
    def test_dispatch(X, indices):
        if indices == nb.types.int64:
            def test_dispatch_scalar(X, indices):
                ref_pos = np.empty(3, np.float64)
                ref_pos[:] = X[:, indices]
                return ref_pos
            return test_dispatch_scalar
        elif isinstance(indices, nb.types.Array) and indices.ndim == 1 and indices.dtype == nb.types.int64:
            def test_dispatch_vector(X, indices):
                ref_pos = np.empty((3, len(indices)), np.float64)
                ref_pos[:, :] = X[:, indices]
                return ref_pos
            return test_dispatch_vector
        else:
            assert False # Unsupported
    

    Requiring 64-bit integers specifically can be too restrictive, so I advise mixing generic and specific type tests. For the same reason, avoid testing whether arrays are of one exact type: an array may or may not be contiguous, and its item type may differ while still being compatible with your function.

    Note that generated JIT functions produce implementations that are compiled separately for each input type (not for each input value).