I'm new to JAX, and while reading the docs I found that jitted functions should not contain iterators (section on pure functions). They give this example:
```python
import jax.numpy as jnp
import jax.lax as lax
from jax import jit

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))  # expected result 45

iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))  # unexpected result 0
```
Fiddling with it a little to see whether I could get an error directly instead of undefined behaviour, I wrote:
```python
@jit
def f(x, arr):
    for i in range(10):
        x += arr[i]
    return x

@jit
def f1(x, arr):
    it = iter(arr)
    for i in range(10):
        x += next(it)
    return x

print(f(0, array))   # 45 as expected
print(f1(0, array))  # still 45
```
Is it just chance that the jitted function `f1()` now behaves correctly?
Your code works because of the way JAX's tracing model works. When JAX traces a function and encounters Python control flow, such as a `for` loop, the loop is fully evaluated at trace time, which effectively unrolls it (there's some exploration of this in JAX Sharp Bits: Control Flow).
Because of this, your use of an iterator in this context is fine: every iteration runs in Python during tracing, so `next(it)` is re-evaluated at every iteration and each unrolled step picks up the next element of `arr`.
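To see the unrolling directly, here is a minimal sketch that wraps `next` in a Python-side counter (`counting_next`, `calls_to_next` and `f1_counted` are hypothetical names used only for illustration). The counter changes only while JAX traces the function; calling the compiled function again with the same argument types reuses the cached trace, so the Python loop does not run a second time.

```python
import jax
import jax.numpy as jnp

# Hypothetical counter: incremented only while Python code runs,
# i.e. during tracing, never during execution of the compiled code.
calls_to_next = 0

def counting_next(it):
    global calls_to_next
    calls_to_next += 1
    return next(it)

@jax.jit
def f1_counted(x, arr):
    it = iter(arr)              # iterating a traced array yields one traced element per step
    for _ in range(10):         # plain Python loop: unrolled into 10 separate additions at trace time
        x += counting_next(it)
    return x

array = jnp.arange(10)
first = f1_counted(0, array)    # traces once: next() is called 10 times
print(first, calls_to_next)     # 45 10
second = f1_counted(0, array)   # same argument types/shapes: cached, no retrace
print(second, calls_to_next)    # 45 10
```

The trade-off is that the traced program grows with the number of iterations, which is one reason `lax.fori_loop` exists for long loops.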
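In contrast, `lax.fori_loop` traces its body function only once, so `next(iterator)` is executed a single time and its output is treated as a trace-time constant that does not change during the runtime iterations. Since that single call returns 0, every iteration adds 0, which is why the result is 0.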
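A small sketch of the same effect with nonzero values makes the baked-in constant visible (the list `values` is made up for illustration). The carry is created with `jnp.float32` so its type matches the body's output and the body is traced once. The fix is the one already used in the docs example: turn the data into an array and index it with the loop counter `i`, which is a traced value that changes on every iteration.

```python
import jax.numpy as jnp
from jax import lax

values = [1.0, 2.0, 3.0]       # illustrative data
init = jnp.float32(0.0)        # strongly typed float32 carry

# Anti-pattern: next(it) runs once while fori_loop traces the body,
# so the first element (1.0) becomes a constant added on every iteration.
it = iter(values)
bad = lax.fori_loop(0, 3, lambda i, x: x + next(it), init)
print(bad)   # 3.0

# Fix: index an array with the traced loop counter instead of using an iterator.
arr = jnp.array(values, dtype=jnp.float32)
good = lax.fori_loop(0, 3, lambda i, x: x + arr[i], init)
print(good)  # 6.0
```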