
Does Equinox (JAX) not broadcast over a batch dimension, and instead expect you to use vmap?


https://docs.kidger.site/equinox/api/nn/mlp/#equinox.nn.MLP

The only way I was able to get MLP working is like this:

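import jax
import equinox as eqx
import numpy as np

key = jax.random.PRNGKey(0)  # PRNG key used to initialise the MLP's weights

# vmap maps the MLP over the leading batch axis: input (5, 12) -> output (5, 4)
jax.vmap(eqx.nn.MLP(in_size=12, out_size=4, width_size=6, depth=5, key=key))(np.random.randn(5, 12))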

Is this the intended usage? If so, it differs a bit from other frameworks, though it may be safer.


Solution

  • Yup, this is intended!

    Every layer in eqx.nn acts on a single batch element, and you can apply them to batches by calling jax.vmap, exactly as you're doing.

    See also this FAQ entry: https://docs.kidger.site/equinox/faq/#how-do-i-input-higher-order-tensors-eg-with-batch-dimensions-into-my-model

    I hope that helps!