I'm trying to understand JAX's auto-vectorization capabilities using vmap and implemented a minimal working example based on JAX's documentation. I don't understand how in_axes is used correctly. In the example below I can set in_axes=(None, 0) or in_axes=(None, 1), leading to the same results. Why is that the case? And why do I have to use in_axes=(None, 0) and not something like in_axes=(0, )?

import jax.numpy as jnp
from jax import vmap


def predict(params, input_vec):
    # Forward pass for a single (unbatched) input vector
    assert input_vec.ndim == 1
    activations = input_vec
    for W, b in params:
        outputs = jnp.dot(W, activations) + b
        activations = jnp.tanh(outputs)
    return outputs


if __name__ == "__main__":
    # Parameters
    dims = [2, 3, 5]
    input_dims = dims[0]
    batch_size = 2

    # Weights
    params = list()
    for dims_in, dims_out in zip(dims, dims[1:]):
        params.append((jnp.ones((dims_out, dims_in)), jnp.ones((dims_out,))))

    # Input data
    input_batch = jnp.ones((batch_size, input_dims))

    # With vmap
    predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)
    print(predictions)

in_axes=(None, 0) means that the first argument (here params) will not be mapped, while the second argument (here input_vec) will be mapped along axis 0.
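
To make that concrete, here is a minimal sketch comparing vmap with an explicit Python loop over axis 0. The single-layer function affine and the array values are made up for illustration and are not part of the question's code; the intent is only to show that the parameters are reused as-is for every example while the batch is sliced row by row.

```python
import jax.numpy as jnp
from jax import vmap


def affine(params, x):
    # Hypothetical one-layer stand-in for predict: params is shared,
    # x is a single 1-D input vector.
    W, b = params
    return jnp.dot(W, x) + b


W = jnp.arange(6.0).reshape(3, 2)      # weights, shape (3, 2)
b = jnp.ones(3)                        # bias, shape (3,)
batch = jnp.arange(8.0).reshape(4, 2)  # 4 input vectors of length 2

# None: pass (W, b) unchanged to every call; 0: slice batch along axis 0.
batched = vmap(affine, in_axes=(None, 0))((W, b), batch)

# The same computation written as an explicit loop over rows of batch.
looped = jnp.stack([affine((W, b), x) for x in batch])

print(batched.shape)
print(jnp.allclose(batched, looped))
```

The two results should agree, with one output row per row of batch.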
> In the example below I can set in_axes=(None, 0) or in_axes=(None, 1) leading to the same results. Why is that the case?
This is because input_batch is a 2x2 matrix of ones, so whether you map along axis 0 or axis 1, each input vector is a length-2 vector of ones. In more general cases, the two specifications are not equivalent, which you can see by either (1) making batch_size differ from input_dims, or (2) filling your arrays with non-constant values.
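
As a rough illustration of both points, the sketch below maps a trivial two-argument function over a non-square, non-constant array. The function scale and the array names are hypothetical and chosen only for this example.

```python
import jax.numpy as jnp
from jax import vmap


def scale(s, x):
    # s is an unmapped scalar, x is one 1-D slice of the mapped array
    assert x.ndim == 1
    return s * x


data = jnp.arange(6.0).reshape(3, 2)  # non-square, non-constant values

rows = vmap(scale, in_axes=(None, 0))(1.0, data)  # 3 slices of length 2
cols = vmap(scale, in_axes=(None, 1))(1.0, data)  # 2 slices of length 3

print(rows.shape, cols.shape)
print(jnp.array_equal(cols, data.T))

# With a square array of ones, as in the question, both choices coincide.
ones = jnp.ones((2, 2))
same = jnp.array_equal(
    vmap(scale, in_axes=(None, 0))(1.0, ones),
    vmap(scale, in_axes=(None, 1))(1.0, ones),
)
print(same)
```

Mapping over axis 1 hands the function the columns of data, and the outputs are stacked along a new leading axis, so here the result is the transpose of the axis-0 result.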
> why do I have to use in_axes=(None, 0) and not something like in_axes=(0, )?
If you set in_axes=(0, ) for a function with two arguments, you get an error because the length of the in_axes tuple must match the number of arguments passed to the function. That said, it is possible to pass a scalar in_axes=0 as a shorthand for in_axes=(0, 0), but for your function this would lead to a shape error because the leading dimensions of the arrays in params do not match the leading dimension of input_batch.
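
If you want to check this yourself, a small sketch along the following lines tries each in_axes specification and reports whether it succeeds. The helper try_in_axes and the one-layer affine function are hypothetical. The exception is caught broadly on purpose, since the exact exception type and message may vary between JAX versions.

```python
import jax.numpy as jnp
from jax import vmap


def affine(params, x):
    # Hypothetical one-layer stand-in for predict
    W, b = params
    return jnp.dot(W, x) + b


params = (jnp.ones((3, 2)), jnp.ones(3))  # leading dimensions 3 and 3
batch = jnp.ones((4, 2))                  # leading dimension 4


def try_in_axes(in_axes):
    # Returns a short status string instead of raising, so all specs can be compared.
    try:
        out = vmap(affine, in_axes=in_axes)(params, batch)
        return f"ok, shape {out.shape}"
    except Exception as err:  # exact type not assumed here
        return f"failed with {type(err).__name__}"


specs = [(None, 0), (0,), 0]
results = {spec: try_in_axes(spec) for spec in specs}
for spec, status in results.items():
    print(spec, "->", status)
```

Only in_axes=(None, 0) is expected to succeed: (0, ) has the wrong length for a two-argument function, and 0 asks vmap to map over arrays whose leading dimensions disagree.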