python numba gsl llvmlite

How to efficiently wrap GSL functions with structs and pointers using Numba?


I'm trying to wrap some GSL (GNU Scientific Library) functions using Numba. I'd like to avoid C glue code wrappers and to be able to cache and extend the wrapped functions.

Here’s a simplified structure of the site-package:

numba_gsl/
├── numba_gsl/
│   ├── gsl_integration.c     # Minimal C wrapper for GSL integration
│   ├── integration.py        # Python-side ctypes + numba wrapper
├── setup.py                  # build script

gsl_integration.c exposes a GSL function:

// gsl_integration.c (must be compiled)
#include <gsl/gsl_integration.h>
#include <stdint.h>

typedef double (*func_ptr)(double x, void* params);

/*
 * Integrate f over the finite interval [a, b] with gsl_integration_qag.
 *
 * f          integrand callback, invoked by GSL as f(x, user_data)
 * user_data  opaque pointer forwarded unchanged to every call of f
 * a, b       integration limits
 * epsabs     requested absolute error
 * epsrel     requested relative error
 * limit      number of sub-intervals the workspace is allocated for
 * key        Gauss-Kronrod rule selector passed straight to GSL
 *
 * Returns the integral estimate; the error estimate is computed but
 * discarded. Returns GSL_NAN if the workspace cannot be allocated.
 *
 * The parameter list is exactly the 8-argument prototype that the ctypes
 * binding declares (two pointers, four doubles, two ints). An extra
 * integer parameter in between would shift `limit` and `key` into the
 * wrong argument slots.
 */
double qag(
        func_ptr f,
        void* user_data,
        double a,
        double b,
        double epsabs,
        double epsrel,
        int limit,
        int key) {
    double result = 0.0, error = 0.0;

    gsl_integration_workspace* w = gsl_integration_workspace_alloc(limit);
    if (w == NULL) {
        /* Only reachable when the GSL error handler has been disabled;
         * never hand a NULL workspace to the integrator. */
        return GSL_NAN;
    }

    gsl_function gsl_func;
    gsl_func.function = f;
    gsl_func.params = user_data;

    gsl_integration_qag(&gsl_func, a, b, epsabs, epsrel, limit, key, w, &result, &error);
    gsl_integration_workspace_free(w);

    return result;
}

In integration.py I load the compiled shared library with ctypes, define argument types, and expose a jitted function that calls into it by passing function pointers obtained via Numba's cfunc.

# integration.py
import ctypes as ct
from numba import njit

def get_extension_path(lib_name: str) -> str:
    search_path = Path(__file__).parent
    pattern = f"*{lib_name}.*"
    matches = search_path.glob(pattern)
    try:
        return str(next(matches))
    except StopIteration:
        return None

# Load the shared GSL integration library (.so)
_lib_path = get_extension_path('gsl_integration')
if _lib_path is None:
    # ct.CDLL(None) does not fail on POSIX: it opens the running process
    # itself, and the missing library only surfaces later as a confusing
    # attribute error on `qag`. Fail here with a clear message instead.
    raise FileNotFoundError(
        "compiled 'gsl_integration' shared library not found next to this module"
    )
_lib = ct.CDLL(_lib_path)

# Define ctypes function prototype for GSL wrapper:
# double qag(void *f, void *params, double a, double b, double epsabs, double epsrel, int limit, int key)
# Both pointers are declared as void* so plain integer addresses can be passed.
qag_c = _lib.qag
qag_c.argtypes = [ct.c_void_p, ct.c_void_p,
                  ct.c_double, ct.c_double,
                  ct.c_double, ct.c_double,
                  ct.c_int, ct.c_int]
qag_c.restype = ct.c_double

@njit
def qag(func_ptr: int,
        a: float,
        b: float,
        epsabs: float = 1.49e-08,
        epsrel: float = 1.49e-08,
        limit: int = 50,
        key: int = 1,
        params_ptr: int = 0) -> float:
    """Integrate a compiled callback over the finite interval [a, b].

    Thin jitted front end for the ``qag_c`` ctypes binding: the callback
    address and the parameter pointer are passed first, followed by the
    limits, the tolerances, ``limit`` and ``key``.
    """
    integral = qag_c(func_ptr, params_ptr, a, b, epsabs, epsrel, limit, key)
    return integral

Here is a build-script:

# setup.py
from setuptools import setup, Extension, find_packages

# The C wrapper is built as an extension module so that setuptools
# compiles it and links it against GSL and its CBLAS library.
ext = Extension(
    name="numba_gsl.gsl_integration",
    sources=["numba_gsl/gsl_integration.c"],
    libraries=["gsl", "gslcblas"],
    extra_compile_args=["-O3"],
)

setup(
    # Project metadata
    name="numba_gsl",
    version="0.1.0",
    description="Numba-friendly wrappers for GSL routines",
    author="No Name",
    # Requirements
    python_requires=">=3.12",
    install_requires=["numba", "llvmlite"],
    # Contents
    packages=find_packages(),
    ext_modules=[ext],
)

And an example:

# example.py
import math
from numba import cfunc, types
from numba_gsl.integration import qag
from scipy.integrate import quad

# Define the integrand as a Numba cfunc with the proper signature:
@cfunc(types.float64(types.float64, types.voidptr))
def sin_over_x(x, _):
    """Return sin(x)/x, with the value 1.0 at x == 0; the pointer argument is ignored."""
    if x == 0.0:
        return 1.0
    return math.sin(x) / x

def py_func(x):
    """Return sin(x)/x, with the value 1.0 at x == 0 (pure-Python reference)."""
    if x == 0:
        return 1.0
    return math.sin(x) / x

# Integrate the same function twice: once through the GSL wrapper, which
# takes the raw address of the compiled callback, and once through SciPy.
func_ptr = sin_over_x.address
qag_res = qag(func_ptr, a=1e-8, b=3.14)
scipy_res, _ = quad(py_func, a=1e-8, b=3.14)  # second item is discarded

for label, value in (
    ("numba_gsl quad result:", qag_res),
    ("scipy     quad result:", scipy_res),
):
    print(label, value)
# numba_gsl quad result: 1.8519366381423115
# scipy     quad result: 1.8519366381423115

Is there a (better) way to wrap complex GSL structs and pass them along with function pointers to GSL routines with Numba, perhaps using:


Solution

  • Wrapping GSL functions in Numba

    There are basically two separate questions to solve.

    How to call a function from a DLL/shared object while still being able to cache it?

    This can be achieved with llvmlite.binding.load_library_permanently and numba.types.ExternalFunction, as shown in the example below. For PyCapsule, I posted an example on the Numba Discourse forum.

    How to pass the arguments, represented as a tuple of anything Numba supports, to GSL's gsl_function struct?

    This was the part I invested most of my time in, since it wasn't the first time I had tried to solve it. The basic point is that a Tuple in Numba is lowered the same way as a struct in C. This should be done with care when calling an external function, however, because the way a C compiler lays out structs in LLVM IR is generally platform dependent. For simple structs like {double*, void*} it should work the same on most platforms.

    Some intrinsics

    import numba as nb
    import numpy as np
    from numba import types,typeof,njit,cfunc
    from numba.core import cgutils
    from numba.extending import intrinsic
    
    @intrinsic
    def val_to_ptr(typingctx, data):
        """
        Store a value in a stack slot and return a pointer to that slot.
        Useful for passing values by reference to functions
        """
        # Resolve the value type once; it is both the argument type and
        # the pointee type of the returned pointer.
        value_ty = typeof(data).instance_type

        def impl(context, builder, signature, args):
            (value,) = args
            return cgutils.alloca_once_value(builder, value)

        return types.CPointer(value_ty)(value_ty), impl
    
    @intrinsic
    def ptr_to_val(typingctx, data):
        """
        Load the value a pointer refers to (the counterpart of val_to_ptr).
        Useful for reading back values that a function wrote by reference
        """
        pointee_ty = data.dtype

        def impl(context, builder, signature, args):
            (pointer,) = args
            return builder.load(pointer)

        return pointee_ty(types.CPointer(pointee_ty)), impl
    
    @intrinsic
    def sizeof_tuple(typingctx, val_type):
        """
        Get the size in bytes of the LLVM struct a Numba tuple is lowered to.

        val_type is the Numba type of the argument; both heterogeneous and
        homogeneous tuple types are accepted. Raises TypeError for any
        other type. The result is returned as an intp.
        """
        # Check the argument type itself. BaseTuple covers Tuple as well
        # as UniTuple (a tuple whose items all share one type).
        if not isinstance(val_type, types.BaseTuple):
            raise TypeError("Argument must be a tuple")

        def codegen(context, builder, sig, args):
            # Reserve a slot of the struct type to have a typed pointer:
            ty = context.get_value_type(val_type)
            ptr = cgutils.alloca_once(builder, ty)

            # GEP with index 1 yields the address one whole struct further on:
            gep = builder.gep(ptr, [context.get_constant(types.int32, 1)])

            # Cast both addresses to integers:
            intptr_t = context.get_value_type(types.intp)
            base = builder.ptrtoint(ptr, intptr_t)
            next_ = builder.ptrtoint(gep, intptr_t)

            # The distance between the two addresses is the struct size:
            size = builder.sub(next_, base)
            return size

        sig = types.intp(val_type)
        return sig, codegen
    
    @intrinsic
    def cast_tuple_to_voidptr(typingctx, args):
        """
        Casts a llvm struct (Numba Tuple) to a void pointer

        The tuple value is stored in a stack slot and the address of that
        slot is returned as a voidptr. Because the slot is a stack
        allocation of the function this intrinsic is compiled into, the
        pointer must not be used after that function has returned.

        Both heterogeneous and homogeneous tuple types are accepted;
        any other argument type raises TypeError.
        """
        # BaseTuple covers Tuple as well as UniTuple, so tuples whose items
        # all share one type (e.g. (1.0, 2.0)) are accepted too.
        if not isinstance(args, types.BaseTuple):
            raise TypeError("Argument must be a tuple")

        # Set the return type
        sig = types.voidptr(args)

        def codegen(context, builder, signature, llvm_args):
            [val] = llvm_args

            # Spill the struct value to the stack to obtain an address.
            tup_ptr = cgutils.alloca_once_value(builder, val)

            voidptr_ty = context.get_value_type(types.voidptr)
            cast_ptr = builder.bitcast(tup_ptr, voidptr_ty)

            return cast_ptr

        return sig, codegen
    
    def gen_cfunc_and_signature(func,args,cache=False):
        """
        Generates a c-func with signature double(double, *some_tuple) of a Numba or Python function

        Parameters:
            func:  Python function or Numba dispatcher taking (x, args_tuple).
                   For a dispatcher the underlying Python function is recompiled.
            args:  Example argument tuple; only its Numba type is used.
            cache: Passed to both the inner njit function and the cfunc.

        Returns the tuple (c_func, sig), where sig is the cfunc signature
        double(double, CPointer(typeof(args))).

        Raises TypeError if func is not callable or args is not a tuple.
        """
        # A Numba dispatcher is itself callable, so one check covers both cases.
        if not callable(func):
            raise TypeError("The first argument func must be a function ")

        if not isinstance(args, tuple):
            raise TypeError("args must be a tuple")

        if isinstance(func,nb.core.registry.CPUDispatcher):
            func = func.py_func

        args_type = nb.typeof(args)
        nb_func = nb.njit(types.double(types.double,args_type),inline="always",cache=cache)(func)
        sig = types.double(types.double,types.CPointer(args_type))

        # Honour the caller's cache flag instead of always enabling caching.
        @cfunc(sig,error_model="numpy",cache=cache)
        def c_func(x,args_ptr):
            # args_ptr points at a single lowered tuple; index 0 loads it.
            return nb_func(x,args_ptr[0])
        return c_func,sig
    
    def gen_cfunc_with_args(func,args,cache=False):
        """
        Generates a c-func with signature double(double, *some_tuple) of a Numba or Python function
        Useful eg. for scipy.integrate.quad low level interface

        Parameters:
            func:  Python function or Numba dispatcher taking (x, args_tuple).
                   For a dispatcher the underlying Python function is recompiled.
            args:  Example argument tuple; only its Numba type is used.
            cache: Passed to every function compiled here.

        Returns the tuple (c_func, func_gen_lowered_struct). The second item
        is a jitted function that takes a tuple and returns a uint8 array
        holding a copy of the bytes of its lowered struct.

        Raises TypeError if func is not callable or args is not a tuple.
        """
        # A Numba dispatcher is itself callable, so one check covers both cases.
        if not callable(func):
            raise TypeError("The first argument func must be a function ")

        if not isinstance(args, tuple):
            raise TypeError("args must be a tuple")

        if isinstance(func,nb.core.registry.CPUDispatcher):
            func = func.py_func

        @njit(cache = cache)
        def func_gen_lowered_struct(tup):
            # The size intrinsic defined in this module is sizeof_tuple.
            size = sizeof_tuple(tup)
            ptr  = cast_tuple_to_voidptr(tup)

            # View the struct as raw bytes, then copy: the pointer refers
            # to a stack slot, so the view itself must not be returned.
            view_to_tuple = nb.carray(ptr,(size,),dtype=np.uint8)
            lowered = np.copy(view_to_tuple)
            return lowered

        args_type = nb.typeof(args)
        nb_func = nb.njit(types.double(types.double,args_type),inline="always",cache = cache)(func)

        @cfunc(types.double(types.double,types.CPointer(args_type)),error_model="numpy",cache = cache)
        def c_func(x,args_ptr):
            # args_ptr points at a single lowered tuple; index 0 loads it.
            return nb_func(x,args_ptr[0])
        return c_func,func_gen_lowered_struct
    
    
    @njit(inline="always")
    def create_gsl_function(func_address, args):
        """Build the pair (function address, pointer to args) and return a void pointer to it.
        This is not safe in general: it assumes Numba lowers the pair with the same
        memory layout the C compiler uses for the matching two-pointer struct.
        """
        # First take the address of the argument tuple ...
        params_ptr = cast_tuple_to_voidptr(args)
        # ... then the address of the (callback, params) pair itself.
        function_pair = (func_address, params_ptr)
        return cast_tuple_to_voidptr(function_pair)
    

    The actual wrapper

    from llvmlite import binding
    from numba import types
    import numba as nb
    import gsl_intrinsics as gsl_intr
    # Load the GSL shared library into the process; the path is a
    # placeholder and must point at the real library file.
    binding.load_library_permanently(r'path_to_gsl.dll')
    
    # Shortcuts for the Numba types used in the C signatures below
    dble   = types.double
    dble_p = types.CPointer(dble)
    intc    = types.intc
    size_t = types.size_t
    void_p = types.voidptr
    
    # Binding functions ###############
    # Each binding pairs a C symbol name with its signature. c_func_name
    # and c_sig are scratch names that are overwritten for every binding.

    # int gsl_integration_qag(gsl_func, a, b, epsabs, epsrel, limit, key,
    #                         workspace, result*, error*)
    c_func_name = 'gsl_integration_qag'
    c_sig = intc(void_p,dble,dble,
             dble,dble,size_t,
             intc,void_p,
             dble_p,dble_p)
    
    c_gsl_integration_qag = types.ExternalFunction(c_func_name, c_sig)
    
    # Takes a size and returns the workspace as an opaque pointer
    c_func_name = 'gsl_integration_workspace_alloc'
    c_sig = void_p(size_t,)
    c_gsl_integration_workspace_alloc = types.ExternalFunction(c_func_name, c_sig)
    
    # Takes the workspace pointer and returns nothing
    c_func_name = 'gsl_integration_workspace_free'
    c_sig = types.none(void_p,)
    c_gsl_integration_workspace_free = types.ExternalFunction(c_func_name, c_sig)
    #################################
    
    
    
    @nb.njit(cache=True)
    def nb_gsl_qag(func_address,args, a, b, epsabs, epsrel, limit, key):
        """Call gsl_integration_qag for the callback at func_address.

        args is the tuple handed to the callback by pointer. Returns the
        pair (result, error) written by the GSL routine; the integer
        returned by the routine itself is not inspected.
        """
        # Pointer to the argument tuple, then pointer to the
        # (callback address, argument pointer) pair used as the gsl function.
        params_ptr = gsl_intr.cast_tuple_to_voidptr(args)
        gsl_function_ptr = gsl_intr.cast_tuple_to_voidptr((func_address, params_ptr))

        # Stack slots the routine writes its two outputs into
        result_p = gsl_intr.val_to_ptr(nb.double(0))
        error_p = gsl_intr.val_to_ptr(nb.double(0))

        # The workspace lives only for the duration of this one call
        workspace = c_gsl_integration_workspace_alloc(limit)
        status = c_gsl_integration_qag(
            gsl_function_ptr, a, b,
            epsabs, epsrel, limit,
            key, workspace,
            result_p, error_p,
        )
        c_gsl_integration_workspace_free(workspace)

        integral = gsl_intr.ptr_to_val(result_p)
        abserr = gsl_intr.ptr_to_val(error_p)
        return integral, abserr
    

    Using the wrapped function

    import numba as nb
    from numba import types
    import numpy as np
    import gsl_intrinsics as gsl_intr
    import gsl as gsl
    import math
    
    
    # Build a slightly more complicated argument tuple: two floats and a
    # 3x3 array of ones. Only args[2] is read by the integrand below.
    args = (5.,np.ones((3,3)),8.)
    
    @nb.njit()
    def sin_over_x(x, args):
        """Return sin(x)/x + args[2], or 1.0 when x is exactly zero."""
        if x == 0.0:
            return 1.0
        return math.sin(x) / x + args[2]
    
    func, sig = gsl_intr.gen_cfunc_and_signature(sin_over_x,args,cache=True) # simpler, but caching does not work this way
    
    # With explicitly hand-written types, caching works.
    inside_tuple = (types.float64,types.Array(types.float64,2,'C'),types.float64)
    args_pointer_type = types.CPointer(types.Tuple(inside_tuple))
    explicit_sig = types.double(types.double, args_pointer_type)

    @nb.cfunc(explicit_sig, cache=True)
    def func(x,args_p):
        # args_p points at one lowered tuple; index 0 loads it.
        unpacked_args = args_p[0]
        return sin_over_x(x, unpacked_args)
    
    # Integration limits, tolerances and GSL settings
    a, b = 1e-8, 3.14
    epsabs = epsrel = 1.49e-08
    limit, key = 50, 1

    # Run the integration through the cached wrapper
    result = gsl.nb_gsl_qag(func.address,args, a, b, epsabs, epsrel, limit, key)
    print(result)

    # Report how often the compiled code came from the on-disk cache
    stats = gsl.nb_gsl_qag.stats
    print("Cache hits nb_gsl_qag:", stats.cache_hits)
    print("Cache misses nb_gsl_qag:", stats.cache_misses)

    print("")
    print("Cache hits func:", func.cache_hits)