
"Iteration over a 0-d array" error on an array that is not empty


I'm writing an environment for RL agent training.

My env.step method takes an array of length 3 as its action:

    def scan(self, f, init, xs, length=None):
        if xs is None:
            xs = [None] * length
        carry = init
        ys = []

        for x in xs:
            carry, y = f(carry, x)
            ys.append(y)
        return carry, np.stack(ys)

    def step_env(
        self,
        key: chex.PRNGKey,
        state: EnvState,
        action: Union[int, float, chex.Array],
        params: EnvParams,
    ) -> Tuple[chex.Array, EnvState, jnp.ndarray, jnp.ndarray, Dict[Any, Any]]:
        
        c_action = jnp.clip(action, params.min_action, params.max_action)
        
        _, m1 = self.scan(self.Rx, 0, action[0])
        _, m2 = self.scan(self.Rx, 0, action[1])
        _, m3 = self.scan(self.Rx, 0, action[2])

I vectorize env.step with jax.vmap and then call it:

    obsv, env_state, reward, done, info = jax.vmap(
        env.step, in_axes=(0, 0, 0, None)
    )(rng_step, env_state, action, env_params)

But I get this error:

    Cell In[9], line 65, in PCJ1_0.scan(self, f, init, xs, length)
         63 ys = []
         64 print(xs)
    ---> 65 for x in xs:
         66     carry, y = f(carry, x)
         67     ys.append(y)

        [... skipping hidden 1 frame]

    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/jax/_src/lax/lax.py:1592, in _iter(tracer)
       1590 def _iter(tracer):
       1591   if tracer.ndim == 0:
    -> 1592     raise TypeError("iteration over a 0-d array")  # same as numpy error
       1593   else:
       1594     n = int(tracer.shape[0])

    TypeError: iteration over a 0-d array

How is this possible? If I print the array passed to the scan function (xs), I get an array of length 5 (I vectorized env.step over 5 envs), so its length is not 0:

    Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
      val = Array([-0.25605989, -0.27983692, -1.0055736 , -0.4460616 , -0.8323701 ],      dtype=float32)
      batch_dim = 0

Solution

  • When you print your value, it gives this:

    Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
      val = Array([-0.25605989, -0.27983692, -1.0055736 , -0.4460616 , -0.8323701 ],      dtype=float32)
      batch_dim = 0
    

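    Here float32[] tells you that this is a tracer with dtype float32 and shape []; that is, your array is zero-dimensional within the context of the vmapped function. The val = Array([...]) line lists the values carried along the batch dimension (batch_dim = 0), one per environment; that is not an axis your function can iterate over. Because each vmapped call of step_env sees a single action of shape (3,), action[0] is a scalar.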
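
    A minimal sketch that reproduces the error outside the environment, assuming a batch of 5 environments with 3-component actions. iterate_first_component is a hypothetical stand-in for calling scan on action[0]:

```python
import jax
import jax.numpy as jnp


def iterate_first_component(action):
    # Inside vmap, `action` has shape (3,), so action[0] is a 0-d tracer.
    total = 0.0
    for x in action[0]:  # iterating over a 0-d array raises TypeError
        total = total + x
    return total


actions = jnp.ones((5, 3))  # 5 environments, 3 action components each

try:
    jax.vmap(iterate_first_component)(actions)
    raised = False
except TypeError as err:
    raised = True
    print(err)  # the "iteration over a 0-d array" message
```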

    The purpose of vmap is to efficiently map a function over an axis of an array, so that within the function evaluation the array has one fewer dimension than it does outside the vmapped context. You can see this as follows:

    >>> import jax
    
    >>> def f(x):
    ...  print(f"{x.shape=}")
    ...  print(f"{x=}")
    ...
    >>> x = jax.numpy.arange(4.0)
    
    >>> f(x)
    x.shape=(4,)
    x=Array([0., 1., 2., 3.], dtype=float32)
    
    >>> jax.vmap(f)(x)
    x.shape=()
    x=Traced<ShapedArray(float32[])>with<BatchTrace(level=1/0)> with
      val = Array([0., 1., 2., 3.], dtype=float32)
      batch_dim = 0
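
    The same rule applies to a 2D batch like the one in the question: mapping over the leading axis of a (5, 3) array leaves each call with a (3,) array, and indexing that gives a 0-d array. A small sketch; step_like and the seen_shapes list are hypothetical and only record the static shapes seen while vmap traces the function:

```python
import jax
import jax.numpy as jnp

seen_shapes = []


def step_like(action):
    # Shapes are static, so they can be recorded while vmap traces the function.
    seen_shapes.append(("action", action.shape))
    seen_shapes.append(("action[0]", action[0].shape))
    return action.sum()


actions = jnp.arange(15.0).reshape(5, 3)  # 5 envs x 3 action components
out = jax.vmap(step_like)(actions)

print(seen_shapes)  # [('action', (3,)), ('action[0]', ())]
print(out.shape)    # (5,)
```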
    

    If you're passing a 1D input into your function and you want to manipulate the full 1D array within your function (instead of evaluating the function element-by-element), then it sounds like you should remove the vmap.
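
    Two possible directions, sketched under assumptions since the rest of the environment isn't shown. If each environment's step should loop over the 3 components of its own action, loop over action itself (shape (3,) inside vmap) rather than over action[0]; jax.lax.scan traces the loop body once and works under both vmap and jit. If the function needs the full 1D array, drop the vmap. rx below is a hypothetical placeholder for self.Rx. As a separate caveat, if np in the question's scan is NumPy, np.stack on traced values would likely also fail, which jax.lax.scan avoids.

```python
import jax
import jax.numpy as jnp


def rx(carry, x):
    # Hypothetical stand-in for self.Rx: running sum, emitting the new total.
    new_carry = carry + x
    return new_carry, new_carry


def step_one_env(action):
    # `action` has shape (3,) inside vmap, so it has a leading axis to scan over.
    _, ys = jax.lax.scan(rx, jnp.float32(0.0), action)
    return ys  # shape (3,)


actions = jnp.arange(15.0).reshape(5, 3)  # 5 envs x 3 action components

# Option 1: keep vmap and scan over each environment's own action vector.
per_env = jax.vmap(step_one_env)(actions)
print(per_env.shape)  # (5, 3)

# Option 2: no vmap; scan over a full 1D array, here the first component
# of every environment's action.
_, first_components = jax.lax.scan(rx, jnp.float32(0.0), actions[:, 0])
print(first_components.shape)  # (5,)
```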