In the following code, sim_timestep is batched. I cannot run it because lax.scan(fwd_dynamics, (xk, uk), jnp.arange(sim_timestep)) requires a concrete value for sim_timestep, but since I have vmapped the state_predictor function, sim_timestep is passed in as a traced array. Any help would be greatly appreciated. Thanks, all.
from jax import random
from jax import lax
import jax
import jax.numpy as jnp
import pdb
def fwd_dynamics(x_u, xs):
    x0, uk = x_u
    Delta_T = 0.001
    lwb = 1.2
    psi0 = x0[2][0]
    v0 = x0[3][0]
    vdot0 = uk[0][0]
    delta0 = uk[1][0]
    thetadot0 = uk[2][0]
    xdot = jnp.asarray([[v0*jnp.cos(psi0)],
                        [v0*jnp.sin(psi0)],
                        [v0*jnp.tan(delta0)/(lwb)],
                        [vdot0],
                        [thetadot0]])
    x_next = x0 + xdot*Delta_T
    return (x_next, uk), x_next  # ("carryover", "accumulated")
def state_predictor(xk, uk, sim_timestep):
    (x_next, _), _ = lax.scan(fwd_dynamics, (xk, uk), jnp.arange(sim_timestep))
    return x_next
low = 0 # Adjust minimum value as needed
high = 100 # Adjust maximum value as needed
key = jax.random.PRNGKey(44)
sim_time = jax.random.randint(key, shape=(10, 1), minval=low, maxval=high)
xk = jax.random.uniform(key, shape=(10,5, 1))
uk = jax.random.uniform(key, shape=(10, 3, 1))  # 3 rows, since fwd_dynamics reads uk[0], uk[1] and uk[2]
state_predictor_vmap = jax.jit(jax.vmap(state_predictor,in_axes= 0 ,out_axes=0 ))
x_next = state_predictor_vmap( xk,uk ,sim_time)
print(x_next.shape)
I tried to solve it with the code above, and I am hoping for an alternative way to achieve the same functionality.
What you're asking to do is impossible: scan lengths must be static, and vmapped values are non-static by definition. What you can do instead is replace your scan with a fori_loop or a while_loop, and then the loop boundary does not need to be static. For example, if you implement your function this way and leave the rest of your code unchanged, it should work:
def state_predictor(xk, uk, sim_timestep):
    body_fun = lambda i, x_u: fwd_dynamics(x_u, i)[0]
    x_next, _ = lax.fori_loop(0, sim_timestep[0], body_fun, (xk, uk))
    return x_next
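One caveat with the loop above: when the upper bound of fori_loop is a traced value, JAX lowers it to a while_loop, which cannot be differentiated in reverse mode. If you need gradients, or simply want to keep lax.scan, a different technique that may fit is to scan over a fixed, static number of steps and mask out the updates past each element's own step count. The following is only a sketch under stated assumptions: the names euler_step, state_predictor_masked and MAX_STEPS are hypothetical and not from the code above, MAX_STEPS is assumed to be an upper bound on every entry of sim_time, and uk is assumed to have three rows (vdot, delta, thetadot).

```python
from functools import partial

import jax
import jax.numpy as jnp
from jax import lax

DELTA_T = 0.001  # integration step, same value as in the question
LWB = 1.2        # wheelbase, same value as in the question


def euler_step(x, u):
    """One explicit-Euler step of the dynamics from the question (hypothetical helper)."""
    psi, v = x[2, 0], x[3, 0]
    vdot, delta, thetadot = u[0, 0], u[1, 0], u[2, 0]
    xdot = jnp.stack([v * jnp.cos(psi),
                      v * jnp.sin(psi),
                      v * jnp.tan(delta) / LWB,
                      vdot,
                      thetadot]).reshape(5, 1)
    return x + xdot * DELTA_T


def state_predictor_masked(xk, uk, sim_timestep, max_steps):
    """Scan a static max_steps times; keep the state frozen once i >= sim_timestep."""
    n = sim_timestep[0]  # traced under vmap, which is fine here

    def body(x, i):
        x_new = euler_step(x, uk)
        # Only accept the update while this batch element still has steps left.
        return jnp.where(i < n, x_new, x), None

    x_final, _ = lax.scan(body, xk, jnp.arange(max_steps))
    return x_final


MAX_STEPS = 100  # assumed static bound; must be >= every value in sim_time

# max_steps is bound as a plain Python int, so the scan length stays static.
predict = jax.jit(jax.vmap(partial(state_predictor_masked, max_steps=MAX_STEPS)))

k_t, k_x, k_u = jax.random.split(jax.random.PRNGKey(44), 3)
sim_time = jax.random.randint(k_t, shape=(10, 1), minval=0, maxval=MAX_STEPS)
xk = jax.random.uniform(k_x, shape=(10, 5, 1))
uk = jax.random.uniform(k_u, shape=(10, 3, 1))  # assumption: three control rows

x_next = predict(xk, uk, sim_time)
print(x_next.shape)  # (10, 5, 1)
```

The trade-off is that every batch element pays for MAX_STEPS iterations regardless of its own sim_timestep. Under vmap the fori_loop version also keeps iterating until the largest step count in the batch is reached, so the extra cost is small when MAX_STEPS is close to that maximum.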