At some point in my code, I call a Numba function and all subsequent computations are made with Numba jitted functions until post-processing steps.
Over the past days, I have been looking for an efficient way to send to the Numba part of the code all the variables (booleans, integers, floats, and float arrays mostly) while trying to keep the code readable and clear. In my case, that implies limiting the number of arguments and, if possible, regrouping some variables depending on the system they refer to.
I identified four ways to do this:
cache=True
option. This is a dealbreaker for me, as compilation times may exceed the execution time of the code itself. namedtuples
, if an object from a @jitclass
is initialized within a non jitted function, I observed that it becomes impossible to benefit from the cache=True
option (see this post). In the end, none of these four alternatives allows me to do what I want. I am probably missing something...
Here is what I did in the end: I combined regular Python classes with Numba's @jitclass
in order to maintain the possibility to benefit from the cache=True
option.
Here is my minimal working example (MWE):
import numba as nb
from numba import jit
from numba.experimental import jitclass
# Field specification handed to @jitclass: (attribute name, Numba type) pairs.
spec_cls = [
    ('a', nb.types.float64),
    ('b', nb.types.float64),
]
# python class
class ClsWear_py(object):
    """Plain-Python container holding the two values ``a`` and ``b``."""

    def __init__(self, a, b):
        # Values are stored exactly as given; no type conversion happens here.
        self.a = a
        self.b = b
# mirror Numba class
@jitclass(spec_cls)
class ClsWear(object):
    """Numba jitclass counterpart of ``ClsWear_py``; its fields are typed by ``spec_cls``."""

    def __init__(self, a, b):
        # Both fields are declared float64 in spec_cls.
        self.a = a
        self.b = b
def function_python(obj):
    """Print ``obj.a``, forward the object's fields to the Numba function, and return the results.

    Returns the tuple ``(obj, oa, ob)``, where ``oa`` and ``ob`` are the two
    values returned by ``function_numba``.
    """
    print('from the python class :', obj.a)
    # call of a Numba function => this is where I must list explicitly all the keys of the python class object
    oa, ob = function_numba(obj.a, obj.b)
    return obj, oa, ob
@jit(nopython=True)
def function_numba(oa, ob):
    """Build a ``ClsWear`` instance from the two arguments, print its ``a`` field, and return both fields."""
    # at the beginning of the Numba function, the arguments are used to define the @jitclass object
    obj_nb = ClsWear(oa, ob)
    print('from the numba class :', obj_nb.a)
    return obj_nb.a, obj_nb.b
# main code : build the plain-Python object, then send it through the Python -> Numba call chain
obj_py = ClsWear_py(11,22)
obj_rt, a, b = function_python(obj_py)
The output of this code is:
$ python mwe.py
from the python class : 11
from the numba class : 11.0
On the plus side:
cache=True
is working. But on the downside:
Am I missing something? Is there a more obvious way to do this?
Based on your research (which I agree with), there doesn't seem to be an intuitive way to do this. So I'm going to suggest the least horrible workaround.
The idea is to pass it as a tuple to the jit function and then convert it to the desired class in the jit function.
from typing import NamedTuple
from numba import njit
class Config(NamedTuple):
    """Read-only configuration record with three typed fields."""

    a: int
    b: float
    c: str
@njit(cache=True)
def f(config_values):
    """Rebuild a ``Config`` named tuple from a plain tuple and return its ``a`` field."""
    config = Config(*config_values)  # Convert to a named tuple.
    return config.a
def main():
    """Call ``f`` once with the config passed as a plain tuple, then print the cache statistics."""
    config = Config(1, 2.0, "3")
    f(tuple(config))  # Pass as a tuple.
    print("Cache:", f.stats)


main()
Result (2nd run):
Cache: _CompileStats(cache_path=..., cache_hits=Counter({(Tuple((int64, float64, unicode_type)),): 1}), cache_misses=Counter())
As you can see, it is correctly cached as a tuple.
One of the issues with this workaround is that named tuples are read-only. So you cannot modify fields.
In that case, you could do the same thing with a jitclass.
Also note that you can create a mirror class for numba from its Python class counterpart, because @jitclass
is a regular Python decorator.
from numba import njit
from numba.experimental import jitclass
class Config:
    """Plain-Python configuration object that can also be turned into a jitclass."""

    # These type hints are interpreted as the default specs for jitclass.
    # If you only use primitive types, this should be sufficient.
    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        """Return the fields as a plain ``(a, b, c)`` tuple."""
        return self.a, self.b, self.c
# Mirror jitclass built from the plain Python class (its type hints act as the specs).
JitConfig = jitclass(Config)
# JitConfig = jitclass(specs)(Config) if you need to specify the specs.
@njit(cache=True)
def f(config_values):
    """Rebuild a ``JitConfig`` from a plain tuple and return its ``a`` field."""
    config = JitConfig(*config_values)  # Convert to a jitclass.
    return config.a
def main():
    """Call ``f`` once with the config's fields as a tuple, then print the cache statistics."""
    # Since we use a tuple as an argument, it is not mandatory to use a jitclass here.
    config = Config(1, 2.0, "3")
    f(config.as_tuple())  # Pass as a tuple.
    print("Cache:", f.stats)


main()
You can also use overload
to hide the jitclass entirely.
Notice that in the code below, it is no longer necessary to explicitly use the jitclass.
from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass
class Config:
    """Plain-Python configuration object; its type hints double as the jitclass field types."""

    a: int
    b: float
    c: str

    def __init__(self, a, b, c):
        self.a = a
        self.b = b
        self.c = c

    def as_tuple(self):
        """Return the fields as a plain ``(a, b, c)`` tuple."""
        return self.a, self.b, self.c
# Private jitclass version of Config; jitted code reaches it through the overload of Config.
_JitConfig = jitclass(Config)
@overload(Config, strict=False)
def overload_config_init(*args):
    """Overload ``Config(...)`` inside jitted code so that it builds a ``_JitConfig`` instead."""

    def jit_config_init(*args):
        # Implementation used by jitted callers: forward all arguments to the jitclass.
        return _JitConfig(*args)

    return jit_config_init
@njit(cache=True)
def f(config_values):
    """Build a config object from a plain tuple and return its ``a`` field."""
    # Since it will be overloaded to a jitclass, you can use a Python class here.
    config = Config(*config_values)
    return config.a
def main():
    """Call ``f`` once with the config's fields as a tuple, then print the cache statistics."""
    config = Config(1, 2.0, "3")
    f(config.as_tuple())
    print("Cache:", f.stats)


main()
As for performance, the overhead of these workarounds should be negligible. Here is the benchmark:
import timeit
from typing import NamedTuple
from numba import njit
from numba.core.extending import overload
from numba.experimental import jitclass
class NamedTupleContainer(NamedTuple):
    """Ten-field named tuple used by the benchmark: five ints followed by five floats."""

    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float
class JitclassContainer:
    """Ten-field plain-Python class used by the benchmark: five ints followed by five floats."""

    a0: int
    a1: int
    a2: int
    a3: int
    a4: int
    a5: float
    a6: float
    a7: float
    a8: float
    a9: float

    def __init__(self, a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
        self.a0 = a0
        self.a1 = a1
        self.a2 = a2
        self.a3 = a3
        self.a4 = a4
        self.a5 = a5
        self.a6 = a6
        self.a7 = a7
        self.a8 = a8
        self.a9 = a9

    def as_tuple(self):
        """Return all ten fields, in declaration order, as a plain tuple."""
        return (
            self.a0,
            self.a1,
            self.a2,
            self.a3,
            self.a4,
            self.a5,
            self.a6,
            self.a7,
            self.a8,
            self.a9,
        )
# Private jitclass version of JitclassContainer; jitted code reaches it through the overload below.
_JitclassContainer = jitclass(JitclassContainer)
@overload(JitclassContainer)
def overload_container_init(*args):
    """Overload ``JitclassContainer(...)`` inside jitted code so that it builds a ``_JitclassContainer``."""

    def jit_container_init(*args):
        # Implementation used by jitted callers: forward all arguments to the jitclass.
        return _JitclassContainer(*args)

    return jit_container_init
@njit(cache=True)
def f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9):
    """Benchmark baseline: take the ten values as separate arguments and return the first."""
    return a0
@njit(cache=True)
def f_namedtuple(args):
    """Rebuild a ``NamedTupleContainer`` from a plain tuple and return its ``a0`` field."""
    c = NamedTupleContainer(*args)
    return c.a0
@njit(cache=True)
def f_jitclass(args):
    """Build a ``JitclassContainer`` from a plain tuple and return its ``a0`` field."""
    c = JitclassContainer(*args)
    return c.a0
def main():
    """Time the three argument-passing variants and print the best per-call time of each in nanoseconds."""

    def benchmark(f):
        # Best of 100 repeats, each timing n_runs calls; returns seconds per call.
        n_runs = 10000
        return min(timeit.repeat(f, repeat=100, number=n_runs)) / n_runs

    values = 1, 2, 3, 4, 5, 6.0, 7.0, 8.0, 9.0, 10.0
    a0, a1, a2, a3, a4, a5, a6, a7, a8, a9 = values
    named_tuple_container = NamedTupleContainer(*values)
    jitclass_container = JitclassContainer(*values)

    t = benchmark(lambda: f_multi_args(a0, a1, a2, a3, a4, a5, a6, a7, a8, a9))
    print(f"f_multi_args: {t * 10 ** 9:.0f} ns")

    # The container -> tuple conversion is performed inside the timed callable.
    t = benchmark(lambda: f_namedtuple(tuple(named_tuple_container)))
    print(f"f_namedtuple: {t * 10 ** 9:.0f} ns")

    t = benchmark(lambda: f_jitclass(jitclass_container.as_tuple()))
    print(f"f_jitclass : {t * 10 ** 9:.0f} ns")


main()
Result:
f_multi_args: 525 ns
f_namedtuple: 695 ns
f_jitclass : 652 ns
On my PC, the difference was less than 200 nanoseconds per function call with 10 arguments/fields.