
in_axes keyword in JAX's vmap


I'm trying to understand JAX's auto-vectorization capabilities using vmap, and I have implemented a minimal working example based on JAX's documentation.

I don't understand how in_axes is used correctly. In the example below I can set in_axes=(None, 0) or in_axes=(None, 1) leading to the same results. Why is that the case?

And why do I have to use in_axes=(None, 0) and not something like in_axes=(0, )?

import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


if __name__ == "__main__":

    # Parameters
    dims = [2, 3, 5]
    input_dims = dims[0]
    batch_size = 2

    # Weights
    params = list()
    for dims_in, dims_out in zip(dims, dims[1:]):
        params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))

    # Input data
    input_batch = jnp.ones((batch_size, input_dims))

    # With vmap
    predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
    print(predictions)

Solution

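  • in_axes has one entry per positional argument. in_axes=(None, 0) means that the first argument (here params) is not mapped and is passed unchanged to every call, while the second argument (here input_batch, which predict receives as input_vec) is mapped along axis 0, so each call sees one row.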
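As a minimal sketch of what this means, the vmapped call can be compared against an explicit Python loop that calls predict once per row. The batch of four non-constant examples (jnp.arange(8.0).reshape(4, 2)) is an assumption chosen for illustration; params is built the same way as in the question.

```python
import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    # Same single-example network as in the question.
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


dims = [2, 3, 5]
params = [
    (jnp.ones((d_out, d_in)), jnp.ones((d_out,)))
    for d_in, d_out in zip(dims, dims[1:])
]
# Illustrative, non-constant batch: 4 examples of length 2, shape (4, 2).
input_batch = jnp.arange(8.0).reshape(4, 2)

# None: params is passed unchanged to every call.
# 0: input_batch is split along axis 0, so each call gets a row of shape (2,).
batched = vmap(predict, in_axes=(None, 0))(params, input_batch)

# The same computation written as an explicit loop over the rows.
looped = jnp.stack([predict(params, x) for x in input_batch])

print(batched.shape)                  # (4, 5)
print(jnp.allclose(batched, looped))  # True
```

The loop is only there to show the semantics; vmap produces the same values without a Python-level loop.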

    In the example below I can set in_axes=(None, 0) or in_axes=(None, 1) leading to the same results. Why is that the case?

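    This is because input_batch is a 2x2 matrix of ones, so whether you map along axis 0 (rows) or axis 1 (columns), each call to predict receives a length-2 vector of ones. In general the two specifications are not equivalent, which you can see by either (1) making batch_size differ from input_dims, in which case mapping along axis 1 passes vectors of length batch_size and the product with the first weight matrix fails with a shape error, or (2) filling input_batch with non-constant values, in which case rows and columns differ and so do the predictions.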
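A minimal sketch of case (2), assuming a hypothetical 2x2 input_batch with distinct entries in place of the matrix of ones (params is built as in the question). It also checks that mapping along axis 1 amounts to mapping along axis 0 of the transposed batch.

```python
import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    # Same single-example network as in the question.
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


dims = [2, 3, 5]
params = [
    (jnp.ones((d_out, d_in)), jnp.ones((d_out,)))
    for d_in, d_out in zip(dims, dims[1:])
]

# Square but non-constant (illustrative values):
# rows are [0.1, 0.2] and [0.3, 0.4]; columns are [0.1, 0.3] and [0.2, 0.4].
input_batch = jnp.array([[0.1, 0.2], [0.3, 0.4]])

by_rows = vmap(predict, in_axes=(None, 0))(params, input_batch)
by_cols = vmap(predict, in_axes=(None, 1))(params, input_batch)

# Mapping over axis 1 is the same as mapping over axis 0 of the transpose.
by_cols_T = vmap(predict, in_axes=(None, 0))(params, input_batch.T)

print(jnp.allclose(by_rows, by_cols))    # False
print(jnp.allclose(by_cols, by_cols_T))  # True
```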

    why do I have to use in_axes=(None, 0) and not something like in_axes=(0, )?

    If you set in_axes=(0, ) for a function with two arguments, you get an error, because when in_axes is a tuple its length must match the number of positional arguments passed to the function. You can pass a scalar, in_axes=0, as a shorthand for in_axes=(0, 0), but for your function this leads to an error as well: vmap then maps every array in params along its leading axis too, and all mapped axes must have the same size, whereas the arrays in params have leading dimensions 3 and 5 and input_batch has a leading dimension of 2.
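Both failure modes and the working call can be checked side by side with a small sketch. try_in_axes is a hypothetical helper introduced here that returns the ValueError message vmap raises, or None on success; the exact wording of those messages may vary between JAX versions, so they are printed rather than reproduced.

```python
import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    # Same single-example network as in the question.
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


dims = [2, 3, 5]
params = [
    (jnp.ones((d_out, d_in)), jnp.ones((d_out,)))
    for d_in, d_out in zip(dims, dims[1:])
]
input_batch = jnp.ones((2, 2))


def try_in_axes(in_axes):
    """Hypothetical helper: vmap's ValueError message, or None if the call works."""
    try:
        vmap(predict, in_axes=in_axes)(params, input_batch)
    except ValueError as err:
        return str(err)
    return None


# One entry for two positional arguments: rejected by vmap.
print(try_in_axes((0,)))
# Scalar 0 maps every array, including the weights, whose leading sizes
# (3 and 5) disagree with the batch size (2).
print(try_in_axes(0))
# One entry per argument: params broadcast, input_batch mapped along axis 0.
print(try_in_axes((None, 0)))  # None
```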