Tags: python, floating-point, gpu, jax

Is it expected that vmapping over different input sizes for the same function impacts the accuracy of the result?


I was surprised to see that the output of a function changes slightly depending on the size of an input matrix that is vmapped over inside the function. That is, not only does the size of the output change (which is what I would expect from vmapping), but the numerics also change slightly. (Note that this only occurs in float32 and only on the GPU.)

I wrote a minimal reproducible example to illustrate the behaviour:

import jax
import jax.numpy as jnp
import equinox as eqx

def equinox_vmap(x, mlp):
    out = eqx.filter_vmap(mlp.__call__)(x)
    return out

key = jax.random.PRNGKey(0)
key, network_key = jax.random.split(key, 2)
mlp = eqx.nn.MLP(2, 2, 10, 2, key=network_key)

key, key_x = jax.random.split(key, 2)
x = jax.random.normal(key_x, (10000, 2))

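# Evaluate the first 10 inputs on their own and as part of the full batch of 10000,
# then compare the first 10 outputs of both calls.
error_eqx = equinox_vmap(x[:10], mlp) - equinox_vmap(x, mlp)[:10]
print("eqx error:", error_eqx)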

When running this example I get the output:

eqx error: [[-1.4442205e-04  1.0999292e-04]
 [-5.9515238e-05 -9.1716647e-06]
 [ 1.4841557e-05  5.6132674e-05]
 [ 0.0000000e+00  0.0000000e+00]
 [-9.1642141e-06 -2.5466084e-05]
 [ 3.8832426e-05 -3.3110380e-05]
 [ 3.3825636e-05 -2.4946406e-05]
 [ 4.0918589e-05 -3.2216311e-05]
 [ 1.3601780e-04  8.7693334e-06]
 [ 0.0000000e+00  0.0000000e+00]]

I understand that float32 arithmetic is not exact and that some error is to be expected. However, I was surprised that the result changes depending on how much of the input array is passed to the function. Since the first row of x, i.e., x[0, :], contains the same values in both calls, I expected the first row of the output to be identical as well.

Further notes:

Questions:

Setup:

equinox                  0.13.0
jax                      0.7.0
jax-cuda12-pjrt          0.7.0
jax-cuda12-plugin        0.7.0
jaxlib                   0.7.0
jaxtyping                0.3.2
ml_dtypes                0.5.3
numpy                    2.3.2
nvidia-cublas-cu12       12.9.1.4
nvidia-cuda-cupti-cu12   12.9.79
nvidia-cuda-nvcc-cu12    12.9.86
nvidia-cuda-nvrtc-cu12   12.9.86
nvidia-cuda-runtime-cu12 12.9.79
nvidia-cudnn-cu12        9.11.0.98
nvidia-cufft-cu12        11.4.1.4
nvidia-cusolver-cu12     11.7.5.82
nvidia-cusparse-cu12     12.5.10.65
nvidia-nccl-cu12         2.27.6
nvidia-nvjitlink-cu12    12.9.86
nvidia-nvshmem-cu12      3.3.9
opt_einsum               3.4.0
pip                      24.0
scipy                    1.16.1
setuptools               65.5.0
typing_extensions        4.14.1
wadler_lindig            0.1.7

Any explanations are greatly appreciated.


Solution

  • This is behaving as expected. This is not fundamentally about vmap; it is about floating-point math. Each floating-point operation can introduce a rounding error, and because floating-point addition is not associative ((a + b) + c is not always equal to a + (b + c)), doing the "same" computation in two different ways will generally accumulate rounding errors differently (see Is floating-point math broken? for some discussion of this).

    Running vmap over different batch sizes can result in different sequences of low-level operations (for example, different kernels or reduction orders chosen for different array shapes), which in turn result in different rounding errors.

    As for why this differs between CPU and GPU, it's all about how the floating-point operations are sequenced. A CPU is a largely serial architecture, so it's likely computing matrix products row by row, with the same accumulation order regardless of input size. A GPU is a parallel architecture, and will generally partition the work and accumulate partial results differently depending on the size of the inputs.
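To make the point about doing the "same" computation in two different ways concrete, here is a minimal NumPy sketch (not JAX, and not how XLA or cuBLAS actually schedule their work). It sums the same 10,000 float32 values once strictly left to right and once as a pairwise tree, which is loosely the shape of a parallel reduction. The helper names `sequential_sum` and `pairwise_sum` are made up for illustration.

```python
import numpy as np

# Floating-point addition is not associative: grouping changes the rounding.
assert (0.1 + 0.2) + 0.3 != 0.1 + (0.2 + 0.3)  # 0.6000000000000001 vs 0.6


def sequential_sum(v):
    """Left-to-right float32 accumulation, one element at a time."""
    acc = np.float32(0.0)
    for a in v:
        acc = acc + a  # np.float32 + np.float32 -> np.float32, rounded each step
    return acc


def pairwise_sum(v):
    """Tree reduction: add neighbours in pairs until one value is left.
    Loosely resembles how a parallel device combines partial results."""
    v = np.asarray(v, dtype=np.float32)
    while v.size > 1:
        if v.size % 2:  # pad odd lengths with an exact zero (x + 0 == x)
            v = np.append(v, np.float32(0.0))
        v = v[0::2] + v[1::2]  # elementwise float32 additions
    return v[0]


rng = np.random.default_rng(0)
v = rng.standard_normal(10_000).astype(np.float32)

s_seq = sequential_sum(v)
s_pair = pairwise_sum(v)
exact = float(np.sum(v.astype(np.float64)))  # float64 reference value

print("sequential:", s_seq)
print("pairwise:  ", s_pair)
print("difference:", float(s_seq) - float(s_pair))
```

The two float32 results may differ in their last digits even though they add up exactly the same numbers; both are legitimate roundings of the true sum, and neither is "the" correct float32 answer.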
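To connect this to batch size, the following NumPy sketch contrasts a fixed accumulation order (the CPU-like case described in the answer) with a hypothetical kernel that switches reduction strategy depending on how many rows it receives, loosely mimicking how GPU libraries pick kernels by problem shape. `sequential_matmul`, `size_dependent_matmul` and the `split_threshold` heuristic are invented for illustration; they are not real XLA or cuBLAS behaviour or parameters.

```python
import numpy as np


def sequential_matmul(x, w):
    """Fixed accumulation order: each output element adds its K products one
    at a time (k = 0, 1, 2, ...). The order never depends on the batch size."""
    acc = np.zeros((x.shape[0], w.shape[1]), dtype=np.float32)
    for k in range(x.shape[1]):
        acc += x[:, k:k + 1] * w[k:k + 1, :]  # float32 multiply, then float32 add
    return acc


def size_dependent_matmul(x, w, split_threshold=1_000):
    """Hypothetical 'GPU-like' kernel choice: for large batches, split the
    reduction dimension K into two halves, reduce each half separately
    (as if on different thread blocks), then add the partial results.
    split_threshold is an invented heuristic, not a real library parameter."""
    if x.shape[0] < split_threshold:
        return sequential_matmul(x, w)
    half = x.shape[1] // 2
    lo = sequential_matmul(x[:, :half], w[:half])
    hi = sequential_matmul(x[:, half:], w[half:])
    return lo + hi


rng = np.random.default_rng(0)
x = rng.standard_normal((10_000, 64)).astype(np.float32)
w = rng.standard_normal((64, 8)).astype(np.float32)

# Same first 10 rows evaluated at two batch sizes, mirroring
# equinox_vmap(x[:10], mlp) - equinox_vmap(x, mlp)[:10] in the question.
cpu_like = sequential_matmul(x[:10], w) - sequential_matmul(x, w)[:10]
gpu_like = size_dependent_matmul(x[:10], w) - size_dependent_matmul(x, w)[:10]

print("fixed order, max |diff|:         ", np.abs(cpu_like).max())
print("size-dependent order, max |diff|:", np.abs(gpu_like).max())
```

With the fixed order, the first ten rows come out bit-for-bit identical at both batch sizes, which is what the question expected. With the size-dependent order, the same rows can pick up small differences of the kind shown in the question, although the magnitudes in this toy say nothing about what a real GPU kernel would produce.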