I'm writing an environment for RL agent training.
My `env.step` method takes as its action an array of length 3:
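```python
from typing import Any, Dict, Tuple, Union

import chex
import jax.numpy as jnp
import numpy as np

# Both functions are methods of the environment class;
# EnvState and EnvParams are defined elsewhere in the environment code.


def scan(self, f, init, xs, length=None):
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:  # Python-level loop over the leading axis of xs
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def step_env(
    self,
    key: chex.PRNGKey,
    state: EnvState,
    action: Union[int, float, chex.Array],
    params: EnvParams,
) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
    c_action = jnp.clip(action,
                        params.min_action,
                        params.max_action)
    _, m1 = self.scan(self.Rx, 0, action[0])
    _, m2 = self.scan(self.Rx, 0, action[1])
    _, m3 = self.scan(self.Rx, 0, action[2])
    # ... (rest of step_env not shown)
```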
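The `scan` helper above follows the pure-Python reference semantics that the `jax.lax.scan` documentation gives for that primitive. As a sketch, the comparison below runs both on the same 1-D input, with a hypothetical `running_sum` step function standing in for `Rx` (which the question does not show). `jax.lax.scan` traces the step function once instead of unrolling a Python loop, which is usually preferable inside jitted or vmapped code, but it still needs `xs` to have a leading axis to scan over, so swapping it in does not by itself fix a 0-d input.

```python
import numpy as np
import jax
import jax.numpy as jnp


def scan_reference(f, init, xs, length=None):
    # Same Python loop as the question's scan helper (without self)
    if xs is None:
        xs = [None] * length
    carry = init
    ys = []
    for x in xs:
        carry, y = f(carry, x)
        ys.append(y)
    return carry, np.stack(ys)


def running_sum(carry, x):
    # Hypothetical step function: accumulate x and emit the running total
    new_carry = carry + x
    return new_carry, new_carry


xs = jnp.arange(4.0)   # 1-D input: there is a leading axis to loop over
init = jnp.zeros(())   # float32 scalar carry

carry_py, ys_py = scan_reference(running_sum, init, xs)
carry_lax, ys_lax = jax.lax.scan(running_sum, init, xs)
print(carry_py, ys_py)
print(carry_lax, ys_lax)
```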
I vectorize `env.step` with `jax.vmap` and then call it:
```python
obsv, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
    rng_step, env_state, action, env_params
)
```
But I get this error:
```
Cell In[9], line 65, in PCJ1_0.scan(self, f, init, xs, length)
     63 ys = []
     64 print(xs)
---> 65 for x in xs:
     66     carry, y = f(carry, x)
     67     ys.append(y)

    [... skipping hidden 1 frame]

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/lax/lax.py:1592, in _iter(tracer)
   1590 def _iter(tracer):
   1591   if tracer.ndim == 0:
-> 1592     raise TypeError("iteration over a 0-d array")  # same as numpy error
   1593   else:
   1594     n = int(tracer.shape[0])

TypeError: iteration over a 0-d array
```
How is this possible? If I print the action array in the `scan` function, I get an array of length 5 (I vmapped `env.step` over 5 envs), so the length is not 0:
```
Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([-0.25605989, -0.27983692, -1.0055736 , -0.4460616 , -0.8323701 ], dtype=float32)
  batch_dim = 0
```
When you print your value, it gives this:
```
Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([-0.25605989, -0.27983692, -1.0055736 , -0.4460616 , -0.8323701 ], dtype=float32)
  batch_dim = 0
```
Here `float32[]` tells you that this is a tracer with dtype `float32` and shape `[]`: that is, your array is zero-dimensional within the context of the vmapped function.
The purpose of `vmap` is to efficiently map a function over an axis of an array, so that within the function evaluation the array has one less dimension than it does outside the vmapped context. You can see that this way:
```python
>>> import jax
>>> def f(x):
...   print(f"{x.shape=}")
...   print(f"{x=}")
...
>>> x = jax.numpy.arange(4.0)
>>> f(x)
x.shape=(4,)
x=Array([0., 1., 2., 3.], dtype=float32)
>>> jax.vmap(f)(x)
x.shape=()
x=Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
  val = Array([0., 1., 2., 3.], dtype=float32)
  batch_dim = 0
```
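The same rule explains the question's setup. Assuming the batched action has shape `(5, 3)` (5 environments, 3 action components each, inferred from the question) and is mapped over axis 0 as with `in_axes=(0, 0, 0, None)`, each call of the step function sees an action of shape `(3,)`, so `action[0]` is a 0-d scalar with nothing left to iterate over. The function `per_env` below is a hypothetical stand-in for `step_env`.

```python
import jax
import jax.numpy as jnp

# Hypothetical batch: 5 environments, each with a length-3 action vector
actions = jnp.arange(15.0).reshape(5, 3)


def per_env(action):
    # vmap strips the mapped axis, so this sees one environment's action
    print(f"{action.shape=}")     # (3,)
    print(f"{action[0].shape=}")  # (): a scalar, not a sequence
    return action[0]


first_components = jax.vmap(per_env)(actions)
print(first_components)  # the first component of each env's action, i.e. actions[:, 0]
```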
If you're passing a 1D input into your function and you want to manipulate the full 1D array within your function (instead of evaluating the function element-by-element), then it sounds like you should remove the `vmap`.
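A sketch of two ways to let a Python loop see the full 1-D array, using a hypothetical `scan_full` loop and made-up data. Option 1 drops `vmap`, as suggested above. Option 2 keeps `vmap` for the other arguments but marks the array argument with `in_axes=None` (as the question already does for `env_params`), so it is not mapped and every call receives the whole array. Which option fits depends on whether the loop is meant to run over the batch or over per-environment data.

```python
import jax
import jax.numpy as jnp


def scan_full(xs):
    # Python loop over the leading axis of xs: needs xs.ndim >= 1
    carry = 0.0
    ys = []
    for x in xs:
        carry = carry + x
        ys.append(carry)
    return carry, jnp.stack(ys)


batch = jnp.array([0.5, -0.25, 1.0, 0.0, -1.5])  # hypothetical data, one value per env

# Option 1: no vmap, the function receives the full 1-D array
carry, ys = scan_full(batch)


# Option 2: keep vmap, but do not map xs (in_axes=None), so it stays 1-D
def shifted_total(offset, xs):
    # offset is mapped (one scalar per call); xs is the full, unmapped array
    total, _ = scan_full(xs)
    return total + offset


out = jax.vmap(shifted_total, in_axes=(0, None))(jnp.arange(5.0), batch)
print(carry, ys, out)
```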