python numba

How to re-cache global variable?


I'm trying to change a global state (an opensimplex random seed) that is used by a function called from an @njit-ed function. But once the function is compiled, Numba keeps the value the global had at compile time, and later changes to it have no effect. For example:

from numba import njit

global_var = 3

@njit
def func():
    return global_var - 3

print(func()) # prints 0
global_var = 5
print(func()) # prints 0 again, undesired

Is there some way to change this global state? I've tried using closures and storing the state in a numba.experimental.jitclass, but couldn't get either to work. Here's the jitclass attempt:

from numba import njit, int32
from numba.experimental import jitclass

spec = [
    ('var', int32),
]
@jitclass(spec)
class State:
    def __init__(self, var):
        self.var = var
    def set(self, var):
        self.var = var

state = State(1)

@njit
def get_state():
    return state.var

get_state() # throws error

The jitclass attempt fails with a NumbaNotImplementedError:

Traceback (most recent call last):
  File "D:\rnd\py\landslip\try-jit.py", line 20, in <module>
    get_state()
  File "D:\prog\Python311\Lib\site-packages\numba\core\dispatcher.py", line 468, in _compile_for_args
    error_rewrite(e, 'typing')
  File "D:\prog\Python311\Lib\site-packages\numba\core\dispatcher.py", line 409, in error_rewrite
    raise e.with_traceback(None)
numba.core.errors.NumbaNotImplementedError: Failed in nopython mode pipeline (step: native lowering)
<numba.core.base.OverloadSelector object at 0x00000206A85AC7D0>, (instance.jitclass.State#206a825e050<var:int32>,)
During: lowering "$4load_global.0 = global(state: <numba.experimental.jitclass.boxing.State object at 0x00000206A8597DC0>)" at D:\rnd\py\landslip\try-jit.py (18)

Solution

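  • Is it practical to wrap the function so that the value is read outside the compiled code and passed in as an argument? Arguments are supplied on every call, so unlike a global they are not frozen at compile time: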

    from numba import njit
    
    global_var = 3
    
    @njit
    def func_inner(local_var):
        return local_var - 3
    
    def func():
        return func_inner(global_var)
    
    print(func()) # prints 0
    global_var = 5
    print(func()) # prints 2