https://docs.kidger.site/equinox/api/nn/mlp/#equinox.nn.MLP
The only way I was able to use `MLP` is like this:
```python
import jax
import equinox as eqx
import numpy as np
key = jax.random.PRNGKey(0)  # PRNG key used to initialise the MLP's weights
jax.vmap(eqx.nn.MLP(in_size=12, out_size=4, width_size=6, depth=5, key=key))(np.random.randn(5, 12))
```
Is this the intended usage? It differs a bit from other frameworks, but maybe it's safer.
Yup, this is intended!
Every layer in `eqx.nn` acts on a single batch element, and you can apply them to batches by calling `jax.vmap`, exactly as you're doing.
See also this FAQ entry: https://docs.kidger.site/equinox/faq/#how-do-i-input-higher-order-tensors-eg-with-batch-dimensions-into-my-model
I hope that helps!