Tags: python, numpy, ctypes, numba

Numba cfunc factory with numpy arrays


I want a factory method that creates a cfunc operating on a NumPy array. I am trying to pass the array to the cfunc through a ctypes pointer. Since my original code is rather complicated, I have reduced it to a simple case that reproduces the error I am getting:

import numpy as np
from numba import cfunc, carray
from numba.types import intc, CPointer, float64
import ctypes

class SimpleExample():
    def __init__(self,array):
        self.array = array
        self.n = array.size

    # Define the C signature: double func(double *input_array, int n)
    sig = float64(CPointer(float64), intc)
    sig2 = float64()

    @staticmethod
    @cfunc(sig)
    def sum_array(ptr, n):
        arr = carray(ptr, n)  # Convert raw pointer to NumPy array
        return np.sum(arr)

    def sum_factory(self):
        arr = self.array
        size = self.n
        sum_fun = type(self).sum_array
        arr_ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
        @cfunc(sig2)
        def sum_array_2():
            return sum_fun(arr_ptr,size)
        return sum_array_2.address


a = np.array([1.0, 2.0, 3.0, 4.0], dtype=np.float64)
example = SimpleExample(a)

# Get pointer to factory-generated function
func_ptr = example.sum_factory()

# Create ctypes function pointer with no args (signature: double func())
cfunc_type = ctypes.CFUNCTYPE(ctypes.c_double)
cfunc_instance = cfunc_type(func_ptr)

print("Sum from SimpleExample sum_factory:", cfunc_instance())

This produces the following error:

KeyError                                  Traceback (most recent call last)
  File ".../numba/core/typing/bufproto.py", line 56, in decode_pep3118_format
    return _pep3118_scalar_map[fmt.lstrip('=')]
KeyError: '&<d'

During handling of the above exception, another exception occurred:

NumbaValueError                           Traceback (most recent call last)
  File "<ipython-input>", line 36, in <module>
    func_ptr = example.sum_factory()
  File "<ipython-input>", line 27, in sum_factory
    arr_ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
  ...
  File ".../numba/core/typing/bufproto.py", line 58, in decode_pep3118_format
    raise NumbaValueError("unsupported PEP 3118 format %r" % (fmt,))

NumbaValueError: Failed in nopython mode pipeline (step: nopython frontend)
Untyped global name 'arr_ptr': unsupported PEP 3118 format '&<d'
During: Pass nopython_type_inference
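
The '&<d' in the error message looks like the PEP 3118 buffer format that a ctypes pointer object advertises, which Numba appears to inspect when it tries to type the captured arr_ptr. A minimal sketch using only the standard library follows. Here buf is a hypothetical stand-in for the NumPy array's memory, and the exact string depends on the platform's byte order:

```python
import ctypes

# A small C double buffer standing in for the NumPy array's memory (illustrative).
buf = (ctypes.c_double * 4)(1.0, 2.0, 3.0, 4.0)

# The same kind of object that
# arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) returns.
ptr = ctypes.cast(buf, ctypes.POINTER(ctypes.c_double))

# ctypes objects implement the buffer protocol, so the format string
# they advertise can be inspected through a memoryview.
fmt = memoryview(ptr).format
print(repr(fmt))
```

On a little-endian machine this should print the same '&<d' that appears in the traceback: "&" marks a pointer and "<d" a little-endian double.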

Solution

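  • This question is similar to this one. Numba cannot type the ctypes pointer object arr_ptr when the inner function captures it from the enclosing scope, which is what the "Untyped global name 'arr_ptr'" error reports. The solution is to capture the NumPy array itself and use its .ctypes attribute inside the compiled function, instead of calling arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double)) beforehand. The code becomes: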

    def sum_factory(self):
        arr = self.array
        size = self.n
        sum_fun = type(self).sum_array
        # No longer needed:
        # arr_ptr = arr.ctypes.data_as(ctypes.POINTER(ctypes.c_double))
        @cfunc(sig2)
        def sum_array_2():
            return sum_fun(arr.ctypes, size)  # Use arr.ctypes instead
        return sum_array_2.address
    

    Note: you will probably also need to change your signatures, or replace the line size = self.n with size = np.intc(self.n), so that the argument types match and this works.
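
The type mismatch behind the note above can be checked without Numba. As far as I know, Numba types a plain Python int such as self.n as a pointer-sized integer (intp), while the signature declares intc, a C int. A small sketch comparing the two widths with ctypes, where c_ssize_t is used as a stand-in for intp (an assumption made for illustration):

```python
import ctypes

# Width of a C `int`, the type that numba.types.intc and np.intc describe.
c_int_bytes = ctypes.sizeof(ctypes.c_int)

# Width of a pointer-sized signed integer, used here as a stand-in for intp.
ssize_bytes = ctypes.sizeof(ctypes.c_ssize_t)

print("C int:", c_int_bytes, "bytes")
print("pointer-sized int:", ssize_bytes, "bytes")
```

On typical 64-bit platforms the two widths differ (4 versus 8 bytes), which is why converting the size explicitly with np.intc, or declaring a wider integer in the signature, may be needed.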