
Reproducibility of JAX calculations


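I am using JAX to run Reinforcement Learning (RL) and Multi-Agent Reinforcement Learning (MARL) calculations. I have noticed that the results are not reproducible from run to run on the GPU. They become reproducible when I set the following environment variables: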

import os

# Set these before importing jax.
os.environ["XLA_FLAGS"] = "--xla_gpu_deterministic_ops"
os.environ["JAX_DISABLE_MOST_FASTER_PATHS"] = "1"

Unfortunately, these settings significantly slow down the computation.

Is there any way to deal with this?


Solution

  • On an accelerator such as a GPU, there will generally be a tradeoff between strict bitwise reproducibility and computational speed.

    Why? Fundamentally, floating-point arithmetic only approximates real arithmetic. In particular, floating-point addition is not associative, so the order in which operations are executed can change the result, and the order of operations is a degree of freedom the GPU can exploit to execute code faster.

    As a simple example, consider summing the same array in different orders:

    In [1]: import numpy as np
    
    In [2]: rng = np.random.default_rng(0)
    
    In [3]: x = rng.normal(size=10000).astype('float32')
    
    In [4]: x.sum()
    Out[4]: np.float32(63.11888)
    
    In [5]: x[::-1].sum()
    Out[5]: np.float32(63.118877)
    

    The results differ slightly.
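    The underlying cause is that floating-point addition is not associative. A minimal sketch with three float32 values (chosen purely for illustration) shows that regrouping alone can change the result:

```python
import numpy as np

a = np.float32(1e8)
b = np.float32(-1e8)
c = np.float32(1.0)

# Group the first two terms first: 1e8 + (-1e8) is exactly 0, then 0 + 1 = 1.
left = (a + b) + c

# Group the last two terms first: -1e8 + 1 is not representable in float32
# (neighbouring float32 values near 1e8 are 8 apart), so it rounds back to
# -1e8 and the 1 is lost; 1e8 + (-1e8) then gives 0.
right = a + (b + c)

print(left, right)  # 1.0 0.0
```

    Summing 10,000 numbers, as above, is just many such regroupings chained together, so the rounding differences accumulate in the last few digits.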

    This matters for your question because of how a GPU works: it speeds up vector operations by running them in parallel. To compute a sum, for example, it might split the array into chunks across N cores, sum each chunk independently, and then accumulate the intermediate sums to get the final result.

    If you care mainly about speed, you can sacrifice reproducibility and accumulate those intermediate sums in the order they become ready; that order may vary from run to run and therefore produce slightly different results. If you care mainly about reproducibility, you have to sacrifice some speed by always accumulating the intermediate sums in exactly the same order, which may leave the process waiting on a slow chunk even when a faster one is already done. This is a simplified example, but the same principle applies to any computation parallelized on a GPU.
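    The chunk-and-accumulate strategy described above can be simulated on the CPU with NumPy. This is a hypothetical sketch, not how XLA actually schedules reductions: the helper name `chunked_sum`, the chunk count (64), and the random "completion order" are illustrative assumptions standing in for whatever order the hardware happens to finish chunks in. Accumulating the partial sums in a fixed order gives the same bits on every run; accumulating them in a varying order can change the last digits.

```python
import numpy as np

def chunked_sum(x, n_chunks, order):
    """Hypothetical helper: sum float32 `x` by splitting it into `n_chunks`
    partial sums (one per simulated core), then accumulating the partial
    sums in the given `order`."""
    partials = [chunk.sum() for chunk in np.array_split(x, n_chunks)]
    total = np.float32(0.0)
    for i in order:
        total = total + partials[i]  # float32 accumulation; result depends on order
    return total

rng = np.random.default_rng(0)
x = rng.normal(size=10000).astype(np.float32)
n_chunks = 64  # illustrative assumption, not a real GPU core count

# "Deterministic mode": always accumulate chunks 0, 1, 2, ... in index order.
fixed = [chunked_sum(x, n_chunks, range(n_chunks)) for _ in range(5)]

# "Fast mode" stand-in: accumulate in a different completion order each run,
# as if the chunks finished at different times.
order_rng = np.random.default_rng(1)
varying = [chunked_sum(x, n_chunks, order_rng.permutation(n_chunks))
           for _ in range(5)]

print("distinct results, fixed order:  ", {float(v) for v in fixed})  # one value
print("distinct results, varying order:", {float(v) for v in varying})
```

    Every run in the fixed-order case performs exactly the same sequence of float32 additions, which is why it is bitwise reproducible; the varying-order runs all agree to within rounding error but need not agree bit for bit.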

    So, fundamentally, there will always be a tradeoff between bitwise reproducibility and speed of computation. You've already found the primary flags for controlling this tradeoff (XLA_FLAGS="--xla_gpu_deterministic_ops" and JAX_DISABLE_MOST_FASTER_PATHS=1). Your question seems to be "can I somehow get both speed and strict bitwise reproducibility at once?", and the answer to that is no.