In Numba, I want to pass configuration variables to a function as compile-time constants. Specifically, what I want to do is:
from numba import njit
from numba.typed import List

@njit
def physics(config):
    flagA = config.flagA
    flagB = config.flagB
    aNumbaList = List()
    for i in range(100):
        if flagA:
            aNumbaList.append(i)
        else:
            aNumbaList.append(i/10)
    return aNumbaList
If the config variables were compile-time constants, this would compile, but they are not, so I get an error saying that there are two candidate implementations:
There are 2 candidate implementations:
- Of which 2 did not match due to:
...
...
I looked at one of the Numba meeting minutes and found that there was a way to do this (Numba Meeting: 2024-03-05). I tried that, but it still raises the same error.
Here is the code with the error message:
.. code:: ipython3
from numba import jit, types, njit
from numba.extending import overload
from numba.typed import List
import functools
.. code:: ipython3
class Config():
    def __init__(self, flagA, flagB):
        self._flagA = flagA
        self._flagB = flagB
    @property
    def flagA(self):
        return self._flagA
    @property
    def flagB(self):
        return self._flagB
.. code:: ipython3
@functools.cache
def obj2strkeydict(obj, config_name):
    # unpack object to freevars and close over them
    tmp_a = obj.flagA
    tmp_b = obj.flagB
    assert isinstance(config_name, str)
    tmp_force_heterogeneous = config_name
    @njit
    def configurator():
        d = {'flagA': tmp_a,
             'flagB': tmp_b,
             'config_name': tmp_force_heterogeneous}
        return d
    # return a configuration function that returns a string-key-dict
    # representation of the configuration object.
    return configurator
.. code:: ipython3
@njit
def physics(cfig_func):
    config = cfig_func()
    flagA = config['flagA']
    flagB = config['flagB']
    aNumbaList = List()
    for i in range(100):
        if flagA:
            aNumbaList.append(i)
        else:
            aNumbaList.append(i/10)
    return aNumbaList
.. code:: ipython3
def demo():
    configuration1 = Config(True, False)
    jit_config1 = obj2strkeydict(configuration1, 'config1')
    physics(jit_config1)
.. code:: ipython3
demo()
::
---------------------------------------------------------------------------
TypingError Traceback (most recent call last)
Cell In[83], line 1
----> 1 demo()
Cell In[82], line 4, in demo()
2 configuration1 = Config(True, False)
3 jit_config1 = obj2strkeydict(configuration1, 'config1')
----> 4 physics(jit_config1)
File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:468, in _DispatcherBase._compile_for_args(self, *args, **kws)
464 msg = (f"{str(e).rstrip()} \n\nThis error may have been caused "
465 f"by the following argument(s):\n{args_str}\n")
466 e.patch_message(msg)
--> 468 error_rewrite(e, 'typing')
469 except errors.UnsupportedError as e:
470 # Something unsupported is present in the user code, add help info
471 error_rewrite(e, 'unsupported_error')
File ~/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/dispatcher.py:409, in _DispatcherBase._compile_for_args.<locals>.error_rewrite(e, issue_type)
407 raise e
408 else:
--> 409 raise e.with_traceback(None)
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
- Resolution failure for literal arguments:
No implementation of function Function(<function impl_append at 0x7fd87d253920>) found for signature:
>>> impl_append(ListType[int64], float64)
There are 2 candidate implementations:
- Of which 2 did not match due to:
Overload in function 'impl_append': File: numba/typed/listobject.py: Line 592.
With argument(s): '(ListType[int64], float64)':
Rejected as the implementation raised a specific error:
TypingError: Failed in nopython mode pipeline (step: nopython frontend)
No implementation of function Function(<intrinsic _cast>) found for signature:
>>> _cast(float64, class(int64))
There are 2 candidate implementations:
- Of which 2 did not match due to:
Intrinsic in function '_cast': File: numba/typed/typedobjectutils.py: Line 22.
With argument(s): '(float64, class(int64))':
Rejected as the implementation raised a specific error:
TypingError: cannot safely cast float64 to int64. Please cast explicitly.
raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/typedobjectutils.py:75
During: resolving callee type: Function(<intrinsic _cast>)
During: typing of call at /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py (600)
File "../anaconda3/envs/tardis/lib/python3.11/site-packages/numba/typed/listobject.py", line 600:
def impl(l, item):
casteditem = _cast(item, itemty)
^
raised from /home/sam/anaconda3/envs/tardis/lib/python3.11/site-packages/numba/core/typeinfer.py:1086
- Resolution failure for non-literal arguments:
None
During: resolving callee type: BoundFunction((<class 'numba.core.types.containers.ListType'>, 'append') for ListType[int64])
During: typing of call at /tmp/ipykernel_9889/739598600.py (11)
File "../../../tmp/ipykernel_9889/739598600.py", line 11:
<source missing, REPL/exec in use?>
Any help or any reference to related material would really help me. Thank you.
In Numba, global variables are compile-time constants, so you can use that to do what you want. Here is an example:
import numba as nb # v0.58.1
flagA = True
@nb.njit
def physics():
    # flagA is read from the global scope, so Numba treats it as a compile-time constant
    aNumbaList = nb.typed.List()
    for i in range(100):
        if flagA:
            aNumbaList.append(i)
        else:
            aNumbaList.append(i/10)
    return aNumbaList
This works without error, whereas passing flagA as a parameter results in an error because the items appended in the if and else branches are of different types.
That being said, global variables are not great in terms of software engineering, and you may want to compile the function for different configurations at runtime (e.g. based on an initialisation process, while avoiding writing to global variables).
An alternative solution is to return a function that reads a variable defined in a parent function. That variable is a closure variable of the compiled function, and Numba treats it as a compile-time constant, just like a global. The value read by the compiled function can be passed as a parameter to the parent function. Here is an example:
import numba as nb
def make_physics(flagA):
    @nb.njit
    def fun():
        aNumbaList = nb.typed.List()
        for i in range(100):
            if flagA:
                aNumbaList.append(i)
            else:
                aNumbaList.append(i/10)
        return aNumbaList
    return fun
physics = make_physics(True) # Create a new specialized function each time make_physics is called (compiled on its first call)
physics() # Call the specialized function generated just before
The make_physics version also compiles without error and works as intended. Here is the generated assembly code of the physics function, showing that there is no runtime check of flagA within the main loop:
[...]
movq %rax, %r12 ; r12 = an allocated Python object (the list?)
movq 24(%rax), %rax
movq %r14, (%rax)
xorl %ebx, %ebx ; i = 0
movabsq $NRT_incref, %r13
movabsq $numba_list_append, %rbp
leaq 48(%rsp), %r15 ; (r15 is a pointer to i)
.LBB0_6: ; Main loop
movq %r12, %rcx
callq *%r13 ; Call NRT_incref(r12)
movq %rbx, 48(%rsp)
movq %r14, %rcx
movq %r15, %rdx
callq *%rbp ; Call numba_list_append(r14, pointer_of(i))
testl %eax, %eax
jne .LBB0_7 ; Stop the loop if numba_list_append returned a non-zero value
incq %rbx ; i += 1
movq %r12, %rcx
movabsq $NRT_decref, %rax
callq *%rax ; Call NRT_decref(r12)
cmpq $100, %rbx
jne .LBB0_6 ; Loop as long as i < 100
[...]
Regarding the actual use-case, memoization and Numba function caching can help to avoid compiling the target function many times for the same configuration.