I work on a program that uses JAX on my laptop. I run it there at a small scale to test it, then send it off to a server for batch processing.
In the program I have set these flags for jax:
jax.config.update('jax_captured_constants_report_frames', -1)
jax.config.update('jax_captured_constants_warn_bytes', 128 * 1024 ** 2)
(as well as others, but these are the relevant ones)
This runs fine on my laptop (using sharding to parallelise across CPU cores), but when I run it on the server on GPU, I get this error:
AttributeError: Unrecognized config option: jax_captured_constants_report_frames
(and the same error for jax_captured_constants_warn_bytes if that line runs first)
Why is there this discrepancy? Is there a way to use these flags that works across different JAX installs?
Output of pip list | grep jax on the laptop:
jax 0.6.2
jaxlib 0.6.2
jaxtyping 0.3.2
On the server:
jax 0.6.0
jax-cuda12-pjrt 0.6.0
jax-cuda12-plugin 0.6.0
jaxlib 0.6.0
jaxtyping 0.3.2
EDIT:
As a side note, what is the scope of JAX flags? I have a JAX initialisation function that sets
os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=" + str(cpu_count())
before the rest of the code runs. If I also call jax.config.update(..., ...) in this function, will those settings hold in files that run after it and also import jax, or do I have to set them again? And is there a function to check the current value of these flags?
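A minimal sketch of one way to structure such an initialisation function, assuming everything runs in a single Python process. jax.config is one module-level object, and Python caches the jax module after the first import, so an update made once is seen by every module that later does import jax in that process. Options can be read back as attributes of jax.config (e.g. jax.config.jax_enable_x64). XLA_FLAGS is read when JAX initialises its backends, so it has to be set before that happens. The names init_jax, set_host_device_count and read_options are hypothetical.

```python
import os
from multiprocessing import cpu_count


def set_host_device_count(n=None):
    """Expose n CPU devices to XLA (default: cpu_count()).

    XLA reads XLA_FLAGS when JAX initialises its backends, so this must run
    before that happens; calling it before jax is imported is the safest option.
    Note: this overwrites any XLA_FLAGS value that was already set.
    """
    n = cpu_count() if n is None else n
    os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={n}"
    return os.environ["XLA_FLAGS"]


def init_jax(options):
    """Hypothetical init helper: set XLA_FLAGS, import jax, apply config options."""
    set_host_device_count()
    import jax  # imported only after XLA_FLAGS has been set

    for name, value in options.items():
        # jax.config is a single module-level object; Python caches the jax
        # module in sys.modules, so every later `import jax` in this process
        # sees the same config and therefore these updated values.
        jax.config.update(name, value)
    return jax


def read_options(config, names):
    """Return the current value of each named option, read as an attribute."""
    return {name: getattr(config, name) for name in names}
```

If init_jax({"jax_enable_x64": True}) is called once at the top of the entry script, any module imported afterwards can check read_options(jax.config, ["jax_enable_x64"]) or simply jax.config.jax_enable_x64. A separately launched process, such as the job on the server, has to run the same initialisation itself.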
The jax_captured_constants_report_frames and jax_captured_constants_warn_bytes configuration options were added in JAX v0.6.1 (relevant PR: https://github.com/jax-ml/jax/pull/28157), so JAX 0.6.0 on your server does not recognise them. If you want to use them on your server, you'll have to update JAX to v0.6.1 or later.