In its documentation, JAX describes automatic vectorization with `vmap`. However, aren't JAX operations already vectorized? For example, when adding two vectors, I thought the element-wise additions were already vectorized internally.

My guess is that explicit vectorization is useful when it is hard to add a dimension for broadcasting ourselves, so we resort to a more explicit form of vectorization instead.

EDIT: for example, instead of vectorizing convolution2d over different kernels, I simply stack the kernels, copy and stack the channel, and then perform the convolution2d with this stack of kernels.
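The `vmap` alternative would, I think, look roughly like this (a rough sketch with made-up shapes, using `jax.scipy.signal.convolve2d` just for illustration):

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
image = jax.random.normal(key1, (32, 32))     # a single 2-D image
kernels = jax.random.normal(key2, (8, 3, 3))  # 8 different kernels

# Explicit vectorization: map convolve2d over the leading kernel axis,
# leaving the image un-mapped (in_axes=None broadcasts it to every call).
conv_per_kernel = jax.vmap(convolve2d, in_axes=(None, 0))(image, kernels)
print(conv_per_kernel.shape)  # (8, 34, 34) with the default mode='full'
```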
I have also raised a similar question here: https://github.com/jax-ml/jax/issues/26212

By now I think there is no universal answer to this, and it will remain a matter of taste to a certain degree. However, in some cases there is a clearer answer:
Some functions, such as `jnp.histogram` or `jnp.bincount`, have no natively "batched" version. In this case you can use `vmap` to get a "batched" version of that function (for example, search for "batched_histogram" here: http://axeldonath.com/jax-diffusion-models-pydata-boston-2025/). This is really convenient and avoids explicit loops, which improves readability as well as performance.
`vmap` works over PyTrees. Some libraries (most notably equinox) use this to avoid handling a batch axis in models completely, and instead `vmap` over the whole parameter tree by convention. This frees developers from thinking about the batch axis at all, but when working with equinox you have to stick to that convention. It also only works if operations are independent across batch elements; it does not work for operations such as a "batch norm" (see also https://docs.kidger.site/equinox/examples/stateful/).
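For example, a model written for a single example and then batched from the outside (a sketch assuming equinox is installed; the model and shapes are made up):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

key = jax.random.PRNGKey(0)
# The model is written without any batch axis in mind.
model = eqx.nn.MLP(in_size=3, out_size=2, width_size=16, depth=2, key=key)

x_single = jnp.ones((3,))
y_single = model(x_single)          # shape (2,)

# The batch axis is added from the outside; the model (a PyTree of
# parameters) is closed over and treated as constant by vmap.
x_batch = jnp.ones((64, 3))
y_batch = jax.vmap(model)(x_batch)  # shape (64, 2)
```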
Sometimes it is simply hard to add a dimension for broadcasting manually, so you resort to the more explicit vectorization with `vmap` (basically what you said).

As broadcasting and batch axes are a universally accepted convention in deep learning, I mostly stick with them. But I rely on `vmap` whenever there is no native vectorization, whenever I work with libraries that rely on `vmap` by convention, or whenever I need to vectorize operations along non-conventional axes (basically everything except the batch axis). A sketch of that last case follows below.
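With `in_axes`/`out_axes` you can map over an axis other than the leading one (the function and shapes here are made up for illustration):

```python
import jax
import jax.numpy as jnp

def normalize(v):
    # Written for a single 1-D vector.
    return v / jnp.linalg.norm(v)

# 100 vectors stored as *columns*, i.e. the mapped axis is axis 1,
# not the conventional leading batch axis.
points = jax.random.normal(jax.random.PRNGKey(0), (3, 100))

normalize_columns = jax.vmap(normalize, in_axes=1, out_axes=1)
unit_points = normalize_columns(points)  # shape (3, 100), unit-norm columns
```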