I'm trying to wrap some GSL (GNU Scientific Library) functions using Numba. I'd like to avoid C glue code wrappers and to be able to cache and extend the wrapped functions.
Here’s a simplified structure of the site-package:
numba_gsl/
├── numba_gsl/
│ ├── gsl_integration.c # Minimal C wrapper for GSL integration
│ ├── integration.py # Python-side ctypes + numba wrapper
├── setup.py # build script
gsl_integration.c exposes a GSL function:
// gsl_integration.c (must be compiled)
#include <gsl/gsl_integration.h>
#include <stdint.h>
typedef double (*func_ptr)(double x, void* params);
double qag(
func_ptr f,
void* user_data,
int which,
double a,
double b,
double epsabs,
double epsrel,
int limit,
int key) {
gsl_integration_workspace* w = gsl_integration_workspace_alloc(limit);
double result, error;
gsl_function gsl_func;
gsl_func.function = f;
gsl_func.params = user_data;
gsl_integration_qag(&gsl_func, a, b, epsabs, epsrel, limit, key, w, &result, &error);
gsl_integration_workspace_free(w);
return result;
}
In integration.py I load the compiled shared library with ctypes, define argument types, and expose a jitted function that calls into it by passing function pointers obtained via Numba's cfunc.
# integration.py
import ctypes as ct
from numba import njit
def get_extension_path(lib_name: str) -> str:
search_path = Path(__file__).parent
pattern = f"*{lib_name}.*"
matches = search_path.glob(pattern)
try:
return str(next(matches))
except StopIteration:
return None
# Load the shared GSL integration library (.so)
# get_extension_path returns None when no matching file is found.
# NOTE(review): that None is passed straight to ct.CDLL -- confirm this is the
# intended behaviour when the extension has not been built.
_lib_path = get_extension_path('gsl_integration')
_lib = ct.CDLL(_lib_path)
# Define ctypes function prototype for GSL wrapper:
# double qag(void *f, void *params, double a, double b, double epsabs, double epsrel, int limit, int key)
qag_c = _lib.qag
# Argument order: integrand pointer, params pointer, a, b, epsabs, epsrel, limit, key.
qag_c.argtypes = [ct.c_void_p, ct.c_void_p,
ct.c_double, ct.c_double,
ct.c_double, ct.c_double,
ct.c_int, ct.c_int]
# The wrapper returns the integral estimate as a C double.
qag_c.restype = ct.c_double
@njit
def qag(func_ptr: int,
        a: float,
        b: float,
        epsabs: float = 1.49e-08,
        epsrel: float = 1.49e-08,
        limit: int = 50,
        key: int = 1,
        params_ptr: int = 0) -> float:
    """Integrate over the finite interval [a, b] through the GSL ``qag`` wrapper.

    ``func_ptr`` is the address of the integrand callback and ``params_ptr``
    the address forwarded to it as its second argument (0 by default). The
    remaining arguments are passed through to the C wrapper unchanged.
    """
    integral = qag_c(func_ptr, params_ptr, a, b, epsabs, epsrel, limit, key)
    return integral
Here is a build-script:
# setup.py
from setuptools import setup, Extension, find_packages
# Extension that compiles the C wrapper and links it against GSL and its CBLAS.
# The resulting shared object is the file integration.py looks up and loads
# with ctypes.
ext = Extension(
"numba_gsl.gsl_integration",
sources=["numba_gsl/gsl_integration.c"],
libraries=["gsl", "gslcblas"],
extra_compile_args=["-O3"]
)
# Package metadata; the extension above is built as part of the install.
setup(
name="numba_gsl",
version="0.1.0",
description="Numba-friendly wrappers for GSL routines",
author="No Name",
packages=find_packages(),
ext_modules=[ext],
install_requires=["numba", "llvmlite"],
python_requires=">=3.12",
)
And an example:
# example.py
import math
from numba import cfunc, types
from numba_gsl.integration import qag
from scipy.integrate import quad
# Integrand compiled by Numba into a C callback: double f(double, void*).
@cfunc(types.float64(types.float64, types.voidptr))
def sin_over_x(x, _):
    """Return sin(x)/x, or 1.0 at x == 0; the second argument is ignored."""
    if x == 0.0:
        return 1.0
    return math.sin(x) / x
def py_func(x):
    """Return sin(x)/x, or 1.0 at x == 0 (pure-Python reference integrand)."""
    if x == 0:
        return 1.0
    return math.sin(x) / x
# Address of the compiled callback; this integer is what the jitted wrapper receives.
func_ptr = sin_over_x.address
# Integrate the same function over the same interval with both libraries.
qag_res = qag(func_ptr, a=1e-8, b=3.14)
# Only the first element of quad's return value is kept.
scipy_res = quad(py_func, a=1e-8, b=3.14)[0]
print("numba_gsl quad result:", qag_res)
print("scipy quad result:", scipy_res)
# Reported output:
# numba_gsl quad result: 1.8519366381423115
# scipy quad result: 1.8519366381423115
Is there a (better) way to wrap complex GSL structs and pass them along with function pointers to GSL routines with Numba, perhaps using:
There are basically two separate questions to solve.
How to call a function from a DLL/shared object while still being able to cache it?
This can be achieved with llvmlite.binding.load_library_permanently
and numba.types.ExternalFunction
as shown in the example below.
For PyCapsule, I posted an example on the Numba Discourse forum.
How to pass the arguments, represented as a tuple of anything Numba supports, to GSL as a gsl_function (gsl_func)?
This was the part I invested most of my time in, since it wasn't the first time I had tried to solve it. The basic point is that a tuple in Numba is lowered the same way as a struct in C. This should be done with care when calling an external function, though, because the way a C compiler lays out structs in LLVM IR is generally platform-dependent. For simple structs like
{double*, void*}
this should work the same on most platforms.
Some intrinsics
import numba as nb
import numpy as np
from numba import types,typeof,njit,cfunc
from numba.core import cgutils
from numba.extending import intrinsic
@intrinsic
def val_to_ptr(typingctx, data):
    """
    Store a value in a stack slot and return a pointer to that slot.

    Handy for passing a value by reference to a function.
    """
    value_type = typeof(data).instance_type

    def codegen(context, builder, signature, args):
        value = args[0]
        return cgutils.alloca_once_value(builder, value)

    return types.CPointer(value_type)(value_type), codegen
@intrinsic
def ptr_to_val(typingctx, data):
    """
    Load the value a pointer refers to, e.g. one created with val_to_ptr.

    Handy for reading back a value that was passed by reference.
    """
    pointee_type = data.dtype

    def codegen(context, builder, signature, args):
        pointer = args[0]
        return builder.load(pointer)

    return pointee_type(types.CPointer(pointee_type)), codegen
@intrinsic
def sizeof_tuple(typingctx, val_type):
    """
    Return the size in bytes of the LLVM struct a Numba tuple is lowered to.

    The size is measured as the address distance between a pointer to the
    struct and a pointer one element further on.

    Raises TypeError when the argument is not typed as a ``types.Tuple``.
    """
    # Validate the argument's own type; ``val_type`` is the only name in scope here.
    if not isinstance(val_type, types.Tuple):
        raise TypeError("Argument must be a tuple")

    def codegen(context, builder, sig, args):
        # Reserve a slot of the struct type to obtain a typed pointer.
        ty = context.get_value_type(val_type)
        ptr = cgutils.alloca_once(builder, ty)
        # GEP with index 1 yields the address just past one struct.
        gep = builder.gep(ptr, [context.get_constant(types.int32, 1)])
        # Turn both pointers into integers so they can be subtracted.
        intptr_t = context.get_value_type(types.intp)
        base = builder.ptrtoint(ptr, intptr_t)
        next_ = builder.ptrtoint(gep, intptr_t)
        return builder.sub(next_, base)

    sig = types.intp(val_type)
    return sig, codegen
@intrinsic
def cast_tuple_to_voidptr(typingctx, args):
    """
    Copy a Numba tuple (an LLVM struct) into a stack slot and return the
    slot's address as a void pointer.
    """
    if not isinstance(args, types.Tuple):
        raise TypeError("Argument must be a tuple")

    def codegen(context, builder, signature, llvm_args):
        [tuple_value] = llvm_args
        slot = cgutils.alloca_once_value(builder, tuple_value)
        # Reinterpret the typed struct pointer as an untyped pointer.
        return builder.bitcast(slot, context.get_value_type(types.voidptr))

    return types.voidptr(args), codegen
def gen_cfunc_and_signature(func,args,cache=False):
    """
    Compile a Numba or Python function into a cfunc with the signature
    double(double, *some_tuple).

    Parameters
    ----------
    func : callable or Numba dispatcher
        Function of ``(x, args_tuple)``; for a dispatcher the underlying
        Python function is recompiled with an explicit signature.
    args : tuple
        Example arguments; only their Numba type is used.
    cache : bool
        Passed as ``cache`` to both the jitted function and the cfunc.

    Returns
    -------
    (c_func, sig)
        The compiled cfunc and its Numba signature.

    Raises
    ------
    TypeError
        If ``func`` is not callable or ``args`` is not a tuple.
    """
    # A dispatcher is itself callable, so one check covers both cases.
    if not callable(func):
        raise TypeError("The first argument func must be a function ")
    if not isinstance(args, tuple):
        raise TypeError("args must be a tuple")
    if isinstance(func, nb.core.registry.CPUDispatcher):
        func = func.py_func

    args_type = nb.typeof(args)
    nb_func = nb.njit(types.double(types.double, args_type), inline="always", cache=cache)(func)
    sig = types.double(types.double, types.CPointer(args_type))

    # Honour the caller's ``cache`` choice, as gen_cfunc_with_args does.
    @cfunc(sig, error_model="numpy", cache=cache)
    def c_func(x, args_ptr):
        # args_ptr points at a single tuple; element 0 is that tuple.
        return nb_func(x, args_ptr[0])

    return c_func, sig
def gen_cfunc_with_args(func,args,cache=False):
    """
    Compile a Numba or Python function into a cfunc with the signature
    double(double, *some_tuple).

    Useful e.g. for the low-level interface of scipy.integrate.quad.

    Parameters
    ----------
    func : callable or Numba dispatcher
        Function of ``(x, args_tuple)``; for a dispatcher the underlying
        Python function is recompiled with an explicit signature.
    args : tuple
        Example arguments; only their Numba type is used.
    cache : bool
        Passed as ``cache`` to every function compiled here.

    Returns
    -------
    (c_func, func_gen_lowered_struct)
        The compiled cfunc, and a jitted helper that takes a tuple and
        returns a uint8 array holding a copy of its lowered struct bytes.

    Raises
    ------
    TypeError
        If ``func`` is not callable or ``args`` is not a tuple.
    """
    # A dispatcher is itself callable, so one check covers both cases.
    if not callable(func):
        raise TypeError("The first argument func must be a function ")
    if not isinstance(args, tuple):
        raise TypeError("args must be a tuple")
    if isinstance(func, nb.core.registry.CPUDispatcher):
        func = func.py_func

    @njit(cache=cache)
    def func_gen_lowered_struct(tup):
        # sizeof_tuple is the size intrinsic defined in this module.
        size = sizeof_tuple(tup)
        ptr = cast_tuple_to_voidptr(tup)
        # View the struct's bytes, then copy them so the result does not
        # alias the temporary slot behind ``ptr``.
        view_to_tuple = nb.carray(ptr, (size,), dtype=np.uint8)
        return np.copy(view_to_tuple)

    args_type = nb.typeof(args)
    nb_func = nb.njit(types.double(types.double, args_type), inline="always", cache=cache)(func)

    @cfunc(types.double(types.double, types.CPointer(args_type)), error_model="numpy", cache=cache)
    def c_func(x, args_ptr):
        # args_ptr points at a single tuple; element 0 is that tuple.
        return nb_func(x, args_ptr[0])

    return c_func, func_gen_lowered_struct
@njit(inline="always")
def create_gsl_function(func_address, args):
    """Build the pair (function address, pointer to args) and return its address.

    This is not safe in general: it assumes Numba and the C compiler agree
    on the layout of a two-pointer struct.
    """
    params_ptr = cast_tuple_to_voidptr(args)
    gsl_function_fields = (func_address, params_ptr)
    return cast_tuple_to_voidptr(gsl_function_fields)
The actual wrapper
from llvmlite import binding
from numba import types
import numba as nb
import gsl_intrinsics as gsl_intr
# Make the GSL symbols resolvable by name for the ExternalFunction bindings below.
# The path is a placeholder for the local GSL shared library.
binding.load_library_permanently(r'path_to_gsl.dll')
# Shortcuts for the Numba types used in the C signatures
dble = types.double
dble_p = types.CPointer(dble)
intc = types.intc
size_t = types.size_t
void_p = types.voidptr
# Function bindings ###############
# c_func_name and c_sig are temporaries that are reassigned for each binding.
c_func_name = 'gsl_integration_qag'
# Argument order as used by the caller: gsl_function pointer, a, b, epsabs,
# epsrel, limit, key, workspace pointer, result pointer, error pointer.
c_sig = intc(void_p,dble,dble,
dble,dble,size_t,
intc,void_p,
dble_p,dble_p)
c_gsl_integration_qag = types.ExternalFunction(c_func_name, c_sig)
c_func_name = 'gsl_integration_workspace_alloc'
# Takes the workspace size and returns an opaque workspace pointer.
c_sig = void_p(size_t,)
c_gsl_integration_workspace_alloc = types.ExternalFunction(c_func_name, c_sig)
c_func_name = 'gsl_integration_workspace_free'
# Takes the workspace pointer and returns nothing.
c_sig = types.none(void_p,)
c_gsl_integration_workspace_free = types.ExternalFunction(c_func_name, c_sig)
#################################
@nb.njit(cache=True)
def nb_gsl_qag(func_address, args, a, b, epsabs, epsrel, limit, key):
    """Call gsl_integration_qag for the callback at ``func_address``.

    ``args`` is the tuple handed to the callback through a pointer. Returns
    the pair (result, error) written by GSL; the integer status returned by
    the GSL call is not inspected.
    """
    # Expose the argument tuple as an opaque pointer.
    params_ptr = gsl_intr.cast_tuple_to_voidptr(args)
    # (function address, params pointer) stands in for the gsl_function struct.
    gsl_function_ptr = gsl_intr.cast_tuple_to_voidptr((func_address, params_ptr))
    # Stack slots that receive the integral and its error estimate.
    result_slot = gsl_intr.val_to_ptr(nb.double(0))
    error_slot = gsl_intr.val_to_ptr(nb.double(0))
    # The workspace lives only for the duration of this call.
    work = c_gsl_integration_workspace_alloc(limit)
    status = c_gsl_integration_qag(gsl_function_ptr, a, b,
                                   epsabs, epsrel, limit,
                                   key, work,
                                   result_slot, error_slot)
    c_gsl_integration_workspace_free(work)
    integral = gsl_intr.ptr_to_val(result_slot)
    abs_error = gsl_intr.ptr_to_val(error_slot)
    return integral, abs_error
Using the wrapped function
import numba as nb
from numba import types
import numpy as np
import gsl_intrinsics as gsl_intr
import gsl as gsl
import math
# Generate a slightly more complicated input tuple
args = (5.,np.ones((3,3)),8.)
@nb.njit()
def sin_over_x(x, args):
    """Return sin(x)/x + args[2], using 1.0 for sin(x)/x at x == 0.

    The offset is added in both branches, so the value at x == 0 is
    1.0 + args[2] rather than a bare 1.0.
    """
    ratio = math.sin(x) / x if x != 0.0 else 1.0
    return ratio + args[2]
# Simpler route, but caching does not work with it:
func, sig = gsl_intr.gen_cfunc_and_signature(sin_over_x,args,cache=True)
# With explicit, hand-written types caching works:
inside_tuple = (types.float64,types.Array(types.float64,2,'C'),types.float64)
@nb.cfunc(types.double(types.double,types.CPointer(types.Tuple(inside_tuple))),cache=True)
def func(x,args_p):
    """C callback: read the argument tuple behind ``args_p`` and evaluate sin_over_x."""
    unpacked_args = args_p[0]
    return sin_over_x(x, unpacked_args)
# Integration limits.
a=1e-8
b=3.14
# Absolute and relative error tolerances.
epsabs= 1.49e-08
epsrel= 1.49e-08
# Workspace size / interval limit and rule selector passed to the wrapper.
limit= 50
key = 1
# Prints the (result, error) pair returned by the wrapper.
print(gsl.nb_gsl_qag(func.address,args, a, b, epsabs, epsrel, limit, key))
# Cache statistics, to check whether the compiled code was loaded from cache.
print("Cache hits nb_gsl_qag:", gsl.nb_gsl_qag.stats.cache_hits)
print("Cache misses nb_gsl_qag:", gsl.nb_gsl_qag.stats.cache_misses)
print("")
print("Cache hits func:", func.cache_hits)