I'm reading the JAX documentation for jax.local_devices, which says:
Like jax.devices(), but only returns devices local to a given process.
And the documentation for jax.devices() says:
Returns a list of all devices for a given backend.
I don't understand exactly what local and non-local devices are. Could you please explain the difference between them?
This is discussed in the JAX documentation, in the page Using JAX in multi-host and multi-process environments:
A process’s local devices are those that it can directly address and launch computations on. For example, on a GPU cluster, each host can only launch computations on the directly attached GPUs. On a Cloud TPU pod, each host can only launch computations on the 8 TPU cores attached directly to that host (see the Cloud TPU System Architecture documentation for more details). You can see a process’s local devices via jax.local_devices().
The global devices are the devices across all processes. A computation can span devices across processes and perform collective operations via the direct communication links between devices, as long as each process launches the computation on its local devices. You can see all available global devices via jax.devices(). A process’s local devices are always a subset of the global devices.
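To see the distinction concretely, here is a minimal sketch (assuming a reasonably recent JAX version; the variable names are illustrative only) that prints both lists and checks the subset relationship using each device's process_index attribute. In a single-process run, such as one CPU or a single GPU host, the two lists are identical. The difference only appears in a multi-process job, where each process would typically call jax.distributed.initialize() before querying devices.

```python
import jax

# Assumption: in a real multi-process job, jax.distributed.initialize() has
# already been called in every process. In a single-process run no setup is
# needed, and the local devices are the same as the global devices.

# Global devices: every device across all processes, for the default backend.
global_devices = jax.devices()
# Local devices: only those this process can directly launch computations on.
local_devices = jax.local_devices()

print(f"process {jax.process_index()} of {jax.process_count()}")
print("global devices:", jax.device_count(), global_devices)
print("local devices: ", jax.local_device_count(), local_devices)

# Each device records the index of the process that owns it, so the local
# devices are exactly the global devices whose process_index matches ours.
global_ids = {d.id for d in global_devices}
local_ids = {d.id for d in local_devices}
owned_ids = {d.id for d in global_devices if d.process_index == jax.process_index()}

assert local_ids == owned_ids
assert local_ids <= global_ids  # local devices are a subset of the global ones
```

If the same script were launched on every host of a multi-process job, each process would print the same global list but a different local list, corresponding to the devices attached to that host.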