Tags: python, vectorization, jax

Are JAX operations already vectorized?


The JAX documentation describes automatic vectorization via jax.vmap. However, aren't JAX operations already vectorized? For example, when adding two vectors, I thought the element-wise additions were already vectorized internally.
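For instance, a plain element-wise add in jax.numpy already operates on whole arrays in a single call, with no vmap involved (a minimal sketch):

    import jax.numpy as jnp

    x = jnp.array([1.0, 2.0, 3.0])
    y = jnp.array([10.0, 20.0, 30.0])

    # Element-wise addition is already "vectorized" in the SIMD sense:
    # one call operates on all elements at once.
    print(x + y)  # [11. 22. 33.]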

My guess is that explicit vectorization is useful when it's hard for us to add a dimension for broadcasting, so we fall back to vmap as a more explicit form of vectorization.
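To make the distinction concrete, here is a minimal sketch (the per-example function and array names are illustrative) contrasting implicit broadcasting with an explicit jax.vmap:

    import jax
    import jax.numpy as jnp

    xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 vectors
    w = jnp.array([1.0, 2.0, 3.0])

    # Manual broadcasting: rely on the batch-axis convention.
    manual = xs * w  # w broadcasts across the leading axis

    # Explicit vectorization: write the per-example function, let vmap batch it.
    per_example = lambda x: x * w
    vmapped = jax.vmap(per_example)(xs)

    print(jnp.allclose(manual, vmapped))  # True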

EDIT: for example, instead of vectorizing convolution2d over different kernels with vmap, I simply stack the kernels, copy and stack the channels, and then perform a single convolution2d with this stack of kernels.
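As an alternative to manually stacking, the same effect can be sketched with vmap over the kernel axis; this example assumes jax.scipy.signal.convolve2d for simplicity (a real CNN would more likely use jax.lax.conv_general_dilated):

    import jax
    import jax.numpy as jnp
    from jax.scipy.signal import convolve2d

    image = jnp.ones((28, 28))
    kernels = jnp.stack([jnp.ones((3, 3)), jnp.eye(3)])  # shape (2, 3, 3)

    # The per-kernel function stays 2D; vmap adds the
    # "stack of kernels" dimension for us.
    conv_one = lambda k: convolve2d(image, k, mode='same')
    out = jax.vmap(conv_one)(kernels)  # shape (2, 28, 28)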


Solution

  • I have also raised a similar question here: https://github.com/jax-ml/jax/issues/26212. For now I think there is no universal answer to this, and it will remain a matter of taste to a certain degree. However, in some cases there is a clearer answer:

    As broadcasting and batch axes are universally accepted conventions in deep learning, I mostly stick with them. But I rely on vmap whenever there is no native vectorization, whenever I work with libraries that rely on vmap by convention, or whenever I need to vectorize operations along non-conventional axes (basically anything other than the batch axis); see the sketch below.
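
    For the last point, the in_axes/out_axes arguments are what make vmap convenient along non-conventional axes; a minimal sketch with an illustrative per-vector function:

        import jax
        import jax.numpy as jnp

        def normalize(v):
            # Per-vector function, written without any batch axis in mind.
            return v / jnp.linalg.norm(v)

        batch = jnp.arange(1.0, 13.0).reshape(3, 4)

        # Conventional batch axis (axis 0): maps over rows.
        out0 = jax.vmap(normalize, in_axes=0)(batch)

        # Non-conventional axis (axis 1): in_axes/out_axes make this explicit,
        # with no transposing or reshaping inside the function itself.
        out1 = jax.vmap(normalize, in_axes=1, out_axes=1)(batch)  # maps over columns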