I am using JAX in Python, and I want to run a loop a random number of times. This is part of a function that is jit-compiled later. The small example below shows what I want to do.
num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
arr = []
for i in range(num_iters):
    arr += [i*i]
This works without any error and gives arr=[0,1,4] at the end of the loop (with the fixed seed of 0 that we're using in PRNGKey).
However, if this is part of a jit-compiled function:
@jax.jit
def do_stuff(start):
    num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
    arr = []
    for i in range(num_iters):
        arr += [i*i]
    for value in arr:
        start += value
    return start
I get a TracerIntegerConversionError on num_iters. The function works fine without the jit decorator. How can I get this to work with jit? I basically just want to construct the list arr, whose length depends on a random number. Alternatively, I could use a list of the maximum possible size, but then I'd have to loop over it a random number of times.
Further context
It's possible to avoid the error by using a NumPy random number generator instead:
@jax.jit
def do_stuff(start):
    np_rng = np.random.default_rng()
    num_iters = np_rng.integers(1, 10)
    arr = []
    for i in range(num_iters):
        arr += [i*i]
    for value in arr:
        start += value
    return start
However, this is not what I want. A JAX rng key is passed to my function, and I want to use it to generate num_iters. Otherwise, arr always has the same length, since the NumPy random number is fixed to whatever was drawn at jit-compile time, and I always get the same result without any randomness. However, if I use that rng key as the seed for NumPy (like np.random.default_rng(seed=rng[0])), it again gives the following error:
TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=1/0)>
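The fixed-length behaviour described above comes from the NumPy call running only once, while JAX traces the function; later calls reuse the compiled result. A minimal sketch of that effect follows, where the names `frozen` and `calls` are hypothetical and exist only for this illustration:

```python
import jax
import numpy as np

calls = []  # records every time the Python body actually executes

@jax.jit
def frozen(start):
    # This line runs during tracing only, so the drawn value is
    # baked into the compiled function as a constant.
    n = np.random.default_rng().integers(1, 10)
    calls.append(int(n))
    return start + n

# Repeated calls with the same argument type hit the jit cache.
results = [int(frozen(0)) for _ in range(5)]
print(results, calls)
```

All five results should be identical, and `calls` should contain a single entry, because the Python body is not re-run after the first trace.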
You could use jax.lax.fori_loop for this:
import jax
@jax.jit
def do_stuff(start):
    num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
    return jax.lax.fori_loop(0, num_iters, lambda i, val: val + i * i, start)
print(do_stuff(10))
# 15
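Since the question mentions that a JAX key is passed into the function, the same idea can take the key as an argument. The sketch below assumes a hypothetical signature `(key, start)` and a hypothetical constant `MAX_ITERS` matching the upper bound of 10 used above. It also shows the fixed-size alternative raised in the question: build an array of the maximum length and mask out the entries beyond `num_iters`. Treat both as illustrations rather than the only way to do it.

```python
import jax
import jax.numpy as jnp

MAX_ITERS = 10  # hypothetical exclusive upper bound, as in randint(..., 1, 10)

@jax.jit
def do_stuff_loop(key, start):
    # num_iters is a traced scalar; fori_loop accepts it as the upper bound.
    num_iters = jax.random.randint(key, (), 1, MAX_ITERS)
    return jax.lax.fori_loop(0, num_iters, lambda i, val: val + i * i, start)

@jax.jit
def do_stuff_masked(key, start):
    num_iters = jax.random.randint(key, (), 1, MAX_ITERS)
    idx = jnp.arange(MAX_ITERS)  # static length, known at trace time
    # Keep i*i for i < num_iters and zero out the unused tail.
    squares = jnp.where(idx < num_iters, idx * idx, 0)
    return start + squares.sum()

key = jax.random.PRNGKey(42)
print(do_stuff_loop(key, 10), do_stuff_masked(key, 10))
```

Both functions should return the same value for the same key. One design consideration: with a traced upper bound, `fori_loop` is lowered to a `while_loop`, which does not support reverse-mode differentiation, whereas the masked version has a static shape and avoids that limitation at the cost of always computing `MAX_ITERS` entries.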