Tags: python, random, jit, jax

How to loop a random number of times in jax with jit compilation?


I am using JAX in Python, and I want to loop over some code a random number of times. This is part of a function that is jit-compiled later. The small example below shows what I want to do.

import jax

num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
arr = []
for i in range(num_iters):
  arr += [i*i]

This works without any error and gives arr=[0,1,4] at the end of the loop (with the fixed seed of 0 that we're using in PRNGKey).

However, if this is part of a jit-compiled function:

@jax.jit
def do_stuff(start):
  num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
  arr = []
  for i in range(num_iters):
    arr += [i*i]
  for value in arr:
    start += value
  return start

I get a TracerIntegerConversionError on num_iters. The function works fine without the jit decorator. How can I get this to work with jit? I basically just want to construct the list arr, whose length depends on a random number. Alternatively, I could use a list with the maximum possible size, but then I'd have to loop over it a random number of times.


Further context

It's possible to avoid the error by using a NumPy random number generator instead:

import jax
import numpy as np

@jax.jit
def do_stuff(start):
  np_rng = np.random.default_rng()
  num_iters = np_rng.integers(1, 10)
  arr = []
  for i in range(num_iters):
    arr += [i*i]
  for value in arr:
    start += value
  return start

However, this is not what I want. A JAX rng key is passed to my function, and I wish to use it to generate num_iters. With NumPy, arr always has the same length: the NumPy call runs only once, when the function is traced for jit compilation, and its result is baked into the compiled function as a constant, so I get the same result on every call without any randomness. And if I use that rng key as a seed for NumPy (like np.random.default_rng(seed=rng[0])), I again get an error:

TypeError: SeedSequence expects int or sequence of ints for entropy not Traced<ShapedArray(uint32[])>with<DynamicJaxprTrace(level=1/0)>

Solution

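  • You could use jax.lax.fori_loop for this. Python's range needs a concrete integer, but under jit num_iters is a traced value; fori_loop accepts traced loop bounds: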

    import jax
    
    @jax.jit
    def do_stuff(start):
      num_iters = jax.random.randint(jax.random.PRNGKey(0), (1,), 1, 10)[0]
      return jax.lax.fori_loop(0, num_iters, lambda i, val: val + i * i, start)
    
    print(do_stuff(10))
    # 15
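
The question mentions two things the snippet above does not show: the key is passed into the function rather than hard-coded, and a fixed maximum size could be used instead of a variable-length list. Below is a minimal sketch of both, not a definitive implementation. The names do_stuff_with_key, do_stuff_masked and MAX_ITERS are hypothetical, and the range 1 to 10 is carried over from the question.

```python
import jax
import jax.numpy as jnp

# Hypothetical constant: randint(key, (), 1, 10) draws from [1, 10), so at most 9.
MAX_ITERS = 9


@jax.jit
def do_stuff_with_key(key, start):
    # Draw the trip count from the key supplied by the caller.
    num_iters = jax.random.randint(key, (), 1, 10)
    # fori_loop accepts a traced upper bound, unlike Python's range().
    return jax.lax.fori_loop(0, num_iters, lambda i, val: val + i * i, start)


@jax.jit
def do_stuff_masked(key, start):
    # Same draw, but compute all MAX_ITERS squares and mask out the unused ones.
    num_iters = jax.random.randint(key, (), 1, 10)
    idx = jnp.arange(MAX_ITERS)
    squares = idx * idx
    mask = idx < num_iters
    return start + jnp.sum(jnp.where(mask, squares, 0))


key = jax.random.PRNGKey(0)
print(do_stuff_with_key(key, 10), do_stuff_masked(key, 10))
```

Both functions should return the same value for the same key. According to the jax.lax.fori_loop documentation, a trip count that is not static is implemented with while_loop, which does not support reverse-mode differentiation; the masked variant avoids the loop entirely, at the cost of always doing MAX_ITERS worth of work.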