python, pytorch, jax

JAX vmap with lax.scan when the sequence length differs across the batch dimension


I have the following code, where sim_timestep varies across the batch. I cannot run it because lax.scan(fwd_dynamics, (xk, uk), jnp.arange(sim_timestep)) requires a concrete value for the sequence length, but since I have vmapped the state_predictor function, sim_timestep arrives as a traced array. Any help would be greatly appreciated. Thanks, all.

import jax
import jax.numpy as jnp
from jax import lax


def fwd_dynamics(x_u, xs):
    # One forward-Euler step; xs is the per-step scan input and is unused.
    x0, uk = x_u
    Delta_T = 0.001  # integration step size
    lwb = 1.2
    psi0 = x0[2][0]
    v0 = x0[3][0]
    vdot0 = uk[0][0]
    delta0 = uk[1][0]
    thetadot0 = uk[2][0]

    xdot = jnp.asarray([[v0 * jnp.cos(psi0)],
                        [v0 * jnp.sin(psi0)],
                        [v0 * jnp.tan(delta0) / lwb],
                        [vdot0],
                        [thetadot0]])
    x_next = x0 + xdot * Delta_T
    return (x_next, uk), x_next  # (carry, per-step output)


def state_predictor(xk, uk, sim_timestep):
    # Under vmap, sim_timestep is a tracer, so jnp.arange(sim_timestep) fails here.
    (x_next, _), _ = lax.scan(fwd_dynamics, (xk, uk), jnp.arange(sim_timestep))
    return x_next

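low = 0     # minimum number of simulation steps (inclusive)
high = 100  # maximum number of simulation steps (exclusive)
key = jax.random.PRNGKey(44)
key_t, key_x, key_u = jax.random.split(key, 3)  # independent key for each draw

sim_time = jax.random.randint(key_t, shape=(10, 1), minval=low, maxval=high)

xk = jax.random.uniform(key_x, shape=(10, 5, 1))
uk = jax.random.uniform(key_u, shape=(10, 3, 1))  # fwd_dynamics reads three control inputs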

state_predictor_vmap = jax.jit(jax.vmap(state_predictor, in_axes=0, out_axes=0))
x_next = state_predictor_vmap(xk, uk, sim_time)
print(x_next.shape)

The code above is my attempt. I am hoping for an alternative way to achieve the same functionality.


Solution

  • What you're asking to do is impossible: scan lengths must be static, and vmapped values are non-static by definition.

    What you can do instead is replace your scan with a fori_loop or a while_loop, and then the loop boundary does not need to be static. For example, if you implement your function this way and leave the rest of your code unchanged, it should work:

    def state_predictor(xk, uk, sim_timestep):
      body_fun = lambda i, x_u: fwd_dynamics(x_u, i)[0]
      x_next, _ = lax.fori_loop(0, sim_timestep[0], body_fun, (xk, uk))
      return x_next
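
If you also need reverse-mode gradients, note that a fori_loop with a non-static trip count is implemented as a while_loop, which does not support reverse-mode differentiation. One common workaround, sketched below, is to keep lax.scan but run it for a fixed maximum number of steps and freeze the state once the per-example step count is reached. This is only a sketch under assumptions that are not in the question: MAX_STEPS, euler_step, predict_fori and predict_masked_scan are hypothetical names, sim_time has shape (10,) instead of (10, 1), and uk has three rows so that every control input read by the dynamics exists.

```python
import jax
import jax.numpy as jnp
from jax import lax

DELTA_T = 0.001  # step size, as in the question
LWB = 1.2        # constant lwb from the question's fwd_dynamics
MAX_STEPS = 100  # assumed static upper bound on the step count (hypothetical)


def euler_step(x, u):
    # Same update as the question's fwd_dynamics, for a (5, 1) state
    # and a (3, 1) control input.
    psi, v = x[2, 0], x[3, 0]
    vdot, delta, thetadot = u[0, 0], u[1, 0], u[2, 0]
    xdot = jnp.stack([v * jnp.cos(psi),
                      v * jnp.sin(psi),
                      v * jnp.tan(delta) / LWB,
                      vdot,
                      thetadot])[:, None]
    return x + xdot * DELTA_T


def predict_fori(xk, uk, n_steps):
    # Dynamic trip count: n_steps may be a traced scalar.
    return lax.fori_loop(0, n_steps, lambda i, x: euler_step(x, uk), xk)


def predict_masked_scan(xk, uk, n_steps):
    # Static trip count: always iterate MAX_STEPS times, but only apply
    # the update while the step index i is below n_steps.
    def body(x, i):
        return jnp.where(i < n_steps, euler_step(x, uk), x), None

    x_final, _ = lax.scan(body, xk, jnp.arange(MAX_STEPS))
    return x_final


key_t, key_x, key_u = jax.random.split(jax.random.PRNGKey(44), 3)
sim_time = jax.random.randint(key_t, shape=(10,), minval=0, maxval=MAX_STEPS)
xk = jax.random.uniform(key_x, shape=(10, 5, 1))
uk = jax.random.uniform(key_u, shape=(10, 3, 1))

x_fori = jax.jit(jax.vmap(predict_fori))(xk, uk, sim_time)
x_scan = jax.jit(jax.vmap(predict_masked_scan))(xk, uk, sim_time)
print(x_fori.shape, x_scan.shape)
```

Both versions should produce the same final states. The masked scan always pays for MAX_STEPS iterations, while the vmapped fori_loop runs until the longest example in the batch is done, so which one is preferable depends on whether you need gradients and on how tight the bound is.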