python · deep-learning · pytorch · artificial-intelligence · jax

Difference in variable values between jax non-jit runtime and jit-transformed runtime


I have a deep learning model which I am running jit-transformed like this:

    from jax.experimental import checkify
    import jax

    # Wrap apply with checkify so runtime checks survive the jit transform.
    my_function_checked = checkify.checkify(model.apply)
    model_jitted = jax.jit(my_function_checked)
    err, pred = model_jitted({"params": params}, batch, training=training, rng=rng)
    err.throw()  # re-raise any error captured inside the jitted call

The code compiles fine, but now I want to debug the intermediate values every few steps, save the arrays, and compare them with PyTorch tensors. For this I need to save the arrays repeatedly. The easiest way is to use an IDE's built-in debugger and evaluate a save expression every few steps, but jit-transformed code doesn't allow external debuggers. I can, however, do this after disabling jit (roughly as in the sketch below). Should I expect any discrepancies between the two runs? Can I assume the values in the jit and non-jit runs will stay the same?
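For reference, this is roughly what I mean by disabling jit: a minimal sketch using the jax.disable_jit context manager around the same call (model_jitted, params, batch, training, and rng as above):

    # Inside this context, jit is a no-op, so ops run eagerly one by one
    # and an IDE debugger can stop on any line and inspect live arrays.
    with jax.disable_jit():
        err, pred = model_jitted({"params": params}, batch, training=training, rng=rng)
        err.throw()
        # e.g. set a breakpoint here and evaluate a save expression such as:
        # np.save("pred.npy", np.asarray(pred))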


Solution

  • In general, when comparing the same JAX operation with and without JIT, you should expect equivalence up to typical floating-point rounding error, but you should not expect bitwise equivalence: the compiler may fuse or reorder operations in ways that change how floating-point error accumulates (see the sketch below).
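A minimal, self-contained sketch of what this looks like in practice; f and the random input here are hypothetical stand-ins for model.apply and a real batch. A bitwise comparison may fail while a tolerance-based comparison passes:

    import jax
    import jax.numpy as jnp
    import numpy as np

    # Hypothetical stand-in for model.apply; any jittable function behaves the same way.
    def f(x):
        return jnp.tanh(x @ x.T).sum(axis=0)

    x = jnp.asarray(np.random.default_rng(0).normal(size=(64, 64)), dtype=jnp.float32)

    eager_out = f(x)            # op-by-op execution, no XLA fusion
    jitted_out = jax.jit(f)(x)  # compiled; XLA may fuse or reorder float ops

    print(np.array_equal(eager_out, jitted_out))  # may be False (bitwise comparison)
    np.testing.assert_allclose(eager_out, jitted_out, rtol=1e-5, atol=1e-6)  # should pass

So for your workflow, compare the arrays saved from the jit run against the non-jit (or PyTorch) arrays with an allclose-style check at float32-appropriate tolerances rather than exact equality.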