Tags: multithreading, parallel-processing, jax

Unable to set cpu device count for jax parallelisation?


I have been trying to generalise this JAX program so that it can solve on either CPU or GPU, depending on the machine it's running on (essentially, I need CPU parallelisation to speed up testing, and the GPU for production). I can get JAX to parallelise on the GPU, but no matter what I do, JAX will not use my cpu_count as the number of CPU devices, so my arrays cannot be sharded across cores (for context, I am running on an 8-core, 16-thread laptop processor).

I found out that the host platform device count (xla_force_host_platform_device_count) had to be set before JAX was initialised (it was previously set inside the if statement in the code), but it is still not working. I also tried setting it at the very start of my code (this snippet is from the only file that uses jax itself, but some other files use jnp as a drop-in replacement for numpy).

Can anyone tell me why JAX will not pick up the flag? (The relevant code snippet and Jupyter notebook output are included below.) Thanks.

Relevant code snippet:

from multiprocessing import cpu_count
core_count = cpu_count()

### THIS NEEDS TO BE SET BEFORE JAX IS INITIALISED IN ANY WAY, INCLUDING IMPORTING
# - XLA_FLAGS are read WHEN jax is IMPORTED

# you can see other ways of setting the environment variable that I've tried here

#jax.config.update('xla_force_host_platform_device_count', core_count)
#os.environ["XLA_FORCE_HOST_PLATFORM_DEVICE_COUNT"] = '16'#str(core_count)
#os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=' + str(core_count)
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={cpu_count()}"

import jax

# defaults float data types to 64-bit instead of 32 for greater precision
jax.config.update('jax_enable_x64', True)
jax.config.update('jax_captured_constants_report_frames', -1)
jax.config.update('jax_captured_constants_warn_bytes', 128 * 1024 ** 2)
jax.config.update('jax_traceback_filtering', 'off')
# https://docs.jax.dev/en/latest/gpu_memory_allocation.html
#jax.config.update('xla_python_client_allocator', '\"platform\"')
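# XLA_PYTHON_CLIENT_ALLOCATOR is an environment variable, not a jax.config option,
# so it can't be set via jax.config.update; the value must not contain literal quotes
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform'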

print("\nDefault jax backend:", jax.default_backend())

available_devices = jax.devices()
print(f"Available devices: {available_devices}")

running_device = xla_bridge.get_backend().platform
print("Running device:", running_device, end='')

if running_device == 'cpu':
    print(", with:", core_count, "cores.")

    from jax.sharding import PartitionSpec as P, NamedSharding

    # Create a Sharding object to distribute a value across devices:
    # Assume core_count is the no. of core devices available
    mesh = jax.make_mesh((core_count,), ('cols',))  # 1D mesh for columns

    # Example matrix shape (9, N), e.g., N = 1e7
    #x = jax.random.normal(jax.random.key(0), (9, Np))

    # Specify sharding: don't split axis 0 (rows), split axis 1 (columns) across devices
    # then apply sharding to produce a sharded array from the matrix input
    # and use jax.device_put to distribute it across devices:
    s0_sharded = jax.device_put(s0, NamedSharding(mesh, P(None, 'cols')))  # 'None' means don't shard axis 0

    print(s0_sharded.sharding)            # See the sharding spec
    print(s0_sharded.addressable_shards)  # Check each device's shard
    jax.debug.visualize_array_sharding(s0_sharded)

Output:

Default jax backend: cpu
Available devices: [CpuDevice(id=0)]
Running device: cpu, with: 16 cores.

...

relevant line of my code: --> 258 mesh = jax.make_mesh((core_count,), ('cols',))  # 1D mesh for columns
... jax backend trace
ValueError: Number of devices 1 must be >= the product of mesh_shape (16,)


Solution

  • I tried running your snippet and got a number of errors related to missing imports and undefined names (os is not defined, xla_bridge is not defined, s0 is undefined). This, along with the fact that you're running in a Jupyter notebook, makes me think that you've already imported JAX in your runtime before running this cell.

    As mentioned in your code comments, the XLA device count must be set before JAX is imported in your runtime. Importing any module that itself imports jax or jax.numpy (such as the other files in your project that use jnp) also counts. You should try restarting the Jupyter kernel, then fix the missing imports and variables and rerun your cell as the first execution in your fresh runtime.

    Here's a simple recipe that should work to set the device count while asserting that you've not already imported JAX in another cell in your notebook:

    import os
    import sys
    
    assert "jax" not in sys.modules, "jax already imported: you must restart your runtime"
    os.environ['XLA_FLAGS'] = "--xla_force_host_platform_device_count=8"
    
    import jax
    print(jax.devices())
    # [CpuDevice(id=0), CpuDevice(id=1), CpuDevice(id=2), CpuDevice(id=3), CpuDevice(id=4), CpuDevice(id=5), CpuDevice(id=6), CpuDevice(id=7)]
    

    If running this results in an assertion error, then you'll have to restart your kernel/runtime before running it again.