jax

Configuration options varying between jax installs?


I develop a program that uses jax on my laptop. I run it there at small scale to test it, then send it off to a server for batch processing.

In the program I have set these flags for jax:

jax.config.update('jax_captured_constants_report_frames', -1)
jax.config.update('jax_captured_constants_warn_bytes', 128 * 1024 ** 2)

(as well as others, but these are the relevant ones)

This runs fine on my laptop (using sharding to parallelise across CPU cores), but when I run it on the server's GPU, I get this error:

AttributeError: Unrecognized config option: jax_captured_constants_report_frames

(and the same error for jax_captured_constants_warn_bytes if that line runs first)

Why is there this discrepancy? Is there a way to set these flags that works across different jax installs?

Output of pip list | grep jax on the laptop:

jax                       0.6.2
jaxlib                    0.6.2
jaxtyping                 0.3.2

On the server:

jax                       0.6.0
jax-cuda12-pjrt           0.6.0
jax-cuda12-plugin         0.6.0
jaxlib                    0.6.0
jaxtyping                 0.3.2

EDIT: As a side note, what is the scope of jax flags? I have a jax initialisation function that sets os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=" + str(cpu_count()) before the rest of the code runs. If I set jax.config.update(..., ...) options in there, will they still hold in modules loaded afterwards that also import jax, or do I have to set them again? Is there a function to check the current value of a flag?


Solution

  • The jax_captured_constants_report_frames and jax_captured_constants_warn_bytes configuration options were added in JAX v0.6.1 (relevant PR: https://github.com/jax-ml/jax/pull/28157). Your laptop has JAX 0.6.2, so it recognises them; the server has 0.6.0, so it doesn't. To use them on the server, update JAX (along with the matching jax-cuda12-plugin and jax-cuda12-pjrt packages) to v0.6.1 or later.
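If upgrading the server isn't immediately possible, one workaround for running the same code on both installs is to treat these flags as optional: try to set each one and skip any that the installed JAX doesn't recognise. This is a minimal sketch, assuming `jax.config.update` raises `AttributeError` for unknown names (as in the error message above) and that options can be read back as attributes of `jax.config`. `OPTIONAL_FLAGS`, `set_optional_flags` and `read_flag` are hypothetical names, not part of JAX. The helpers take the config object as a parameter, so in practice you would pass `jax.config`.

```python
import logging
from typing import Any, Mapping

log = logging.getLogger(__name__)

# Flags from the question; they only exist in newer JAX releases.
OPTIONAL_FLAGS: dict[str, Any] = {
    "jax_captured_constants_report_frames": -1,
    "jax_captured_constants_warn_bytes": 128 * 1024 ** 2,
}


def set_optional_flags(config: Any, flags: Mapping[str, Any]) -> list[str]:
    """Call config.update(name, value) for each flag, skipping unknown names.

    `config` is meant to be `jax.config`. Older JAX versions raise
    AttributeError("Unrecognized config option: ...") for names they do not
    define; those names are logged and returned instead of crashing.
    """
    skipped: list[str] = []
    for name, value in flags.items():
        try:
            config.update(name, value)
        except AttributeError:
            log.warning("JAX config option %r not supported here; skipping", name)
            skipped.append(name)
    return skipped


def read_flag(config: Any, name: str, default: Any = None) -> Any:
    """Return the current value of a config option, or `default` if it is
    not available (assumes options are exposed as attributes of `config`).
    """
    return getattr(config, name, default)
```

In the initialisation function you would call `set_optional_flags(jax.config, OPTIONAL_FLAGS)` after importing jax. Because `jax.config` is a single module-level object and Python caches imported modules in `sys.modules`, the settings then apply to every module in the same process that imports jax afterwards; a separate process (for example one started by multiprocessing with the spawn method) would need to set them again. `read_flag(jax.config, "jax_captured_constants_warn_bytes")` returns the current value, or `None` on a version that lacks the option. Catching the error rather than comparing `jax.__version__` avoids hard-coding version numbers, but note that on the 0.6.0 server the skipped flags simply have no effect, so you won't get the captured-constants warnings there.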