I was surprised to see that, depending on the size of an input matrix that is vmapped over inside a function, the output of the function changes slightly. That is, not only does the size of the output change (which is what I would expect from vmapping), but the numerics also change slightly. (Note that this only occurs in float32 and only on the GPU.)
I wrote a minimal reproducible example to illustrate the behaviour:
import jax
import jax.numpy as jnp
import equinox as eqx
def equinox_vmap(x, mlp):
    out = eqx.filter_vmap(mlp.__call__)(x)
    return out
key = jax.random.PRNGKey(0)
key, network_key = jax.random.split(key, 2)
mlp = eqx.nn.MLP(2, 2, 10, 2, key=network_key)
key, key_x = jax.random.split(key, 2)
x = jax.random.normal(key_x, (10000, 2))
error_eqx = equinox_vmap(x[:10], mlp) - equinox_vmap(x, mlp)[:10]
print("eqx error:", error_eqx)
When running this example I get the output:
eqx error: [[-1.4442205e-04 1.0999292e-04]
[-5.9515238e-05 -9.1716647e-06]
[ 1.4841557e-05 5.6132674e-05]
[ 0.0000000e+00 0.0000000e+00]
[-9.1642141e-06 -2.5466084e-05]
[ 3.8832426e-05 -3.3110380e-05]
[ 3.3825636e-05 -2.4946406e-05]
[ 4.0918589e-05 -3.2216311e-05]
[ 1.3601780e-04 8.7693334e-06]
[ 0.0000000e+00 0.0000000e+00]]
I understand that float32 arithmetic is not fully accurate and that some error is to be expected. However, I was surprised that the result changes depending on how much of the input array is passed into the function. I was expecting that the first row of x, i.e. x[0, :], would still hold the same values, and that the first row of the output would therefore be the same.
Further notes:
- I also tried using float64 (jax.config.update("jax_enable_x64", True)), which completely stops this from occurring. I understand that this is a numerical problem, but I am a little confused about how the vmapping interacts with the example.
- When I run the example on the CPU (jax.config.update("jax_platform_name", "cpu")), the problem also disappears, which I also find difficult to understand. (A minimal sketch of both workarounds follows below.)
Questions: Why does the result depend on how much of the input array is passed to the vmapped function, and why does this only happen in float32 on the GPU?
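For reference, here is a minimal sketch of the two workarounds mentioned in the notes (both are standard jax.config flags; they must be set before any arrays or computations are created):

import jax

# Workaround 1: compute in float64 instead of float32.
jax.config.update("jax_enable_x64", True)

# Workaround 2: force everything onto the CPU backend.
jax.config.update("jax_platform_name", "cpu")

# ...then rerun the reproduction script above; the reported differences go away.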
Setup:
equinox 0.13.0
jax 0.7.0
jax-cuda12-pjrt 0.7.0
jax-cuda12-plugin 0.7.0
jaxlib 0.7.0
jaxtyping 0.3.2
ml_dtypes 0.5.3
numpy 2.3.2
nvidia-cublas-cu12 12.9.1.4
nvidia-cuda-cupti-cu12 12.9.79
nvidia-cuda-nvcc-cu12 12.9.86
nvidia-cuda-nvrtc-cu12 12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12 9.11.0.98
nvidia-cufft-cu12 11.4.1.4
nvidia-cusolver-cu12 11.7.5.82
nvidia-cusparse-cu12 12.5.10.65
nvidia-nccl-cu12 2.27.6
nvidia-nvjitlink-cu12 12.9.86
nvidia-nvshmem-cu12 3.3.9
opt_einsum 3.4.0
pip 24.0
scipy 1.16.1
setuptools 65.5.0
typing_extensions 4.14.1
wadler_lindig 0.1.7
Any explanations are greatly appreciated.
This is behaving as expected. This is not fundamentally about vmap; this is about floating-point math. Whenever you're doing floating-point operations, you will accumulate rounding errors, and when you do the "same" computation in two different ways, you will accumulate rounding errors differently (see Is floating-point math broken? for some discussion of this).
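To make this concrete, here is a small sketch (plain NumPy, independent of vmap and Equinox; the random data is just for illustration) that sums the same float32 values with two different accumulation strategies:

import numpy as np

rng = np.random.default_rng(0)
vals = rng.normal(size=10_000).astype(np.float32)

# Strictly sequential float32 accumulation, front to back.
seq = np.float32(0.0)
for v in vals:
    seq += v

# NumPy's built-in sum, which internally uses pairwise summation.
pairwise = np.sum(vals)

print(seq - pairwise)  # usually a small nonzero difference

Mathematically both compute the same sum; in float32 they round differently.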
Running vmap over different batch sizes results in different sequences of operations, which in turn results in different rounding errors.
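You can see the same effect with no vmap at all. Here is a sketch along the lines of the original reproduction, but assuming a plain jnp.dot with a stand-in weight matrix W instead of the Equinox MLP:

import jax
import jax.numpy as jnp

key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
W = jax.random.normal(key_w, (2, 10))  # stand-in weight matrix
x = jax.random.normal(key_x, (10000, 2))

# The same logical rows, but the matmul is dispatched with different batch sizes.
small = jnp.dot(x[:10], W)
large = jnp.dot(x, W)[:10]

print(small - large)  # on GPU in float32, this is often not exactly zero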
As for why this differs between CPU and GPU, it's all about how the floating point operations are sequenced. CPU is a serial architecture, so it's likely computing matrix products row-by-row with the same accumulation orders regardless of input size. GPU is a parallel architecture, and will generally distribute and accumulate results differently depending on the size of the inputs.
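If you want to check this on your own machine (assuming both a CPU and a GPU backend are available in the same process; batch_diff is just an illustrative helper), you can run the identical computation on each backend and compare:

import jax
import jax.numpy as jnp

def batch_diff(backend):
    # Run the size-10 and size-10000 versions of the same matmul on one backend
    # and return the largest disagreement between their shared rows.
    with jax.default_device(jax.devices(backend)[0]):
        key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
        W = jax.random.normal(key_w, (2, 10))
        x = jax.random.normal(key_x, (10000, 2))
        return jnp.abs(jnp.dot(x[:10], W) - jnp.dot(x, W)[:10]).max()

print("cpu:", batch_diff("cpu"))  # typically exactly 0.0
print("gpu:", batch_diff("gpu"))  # typically a small nonzero float32 value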