I am rewriting some code from pure Python to JAX. I have reached the part where my old code used Python's `multiprocessing` module to parallelize the evaluation of a function over all of the CPU cores in a single node, as follows:
```python
import multiprocessing

# `function` and `inputs` are defined elsewhere
if __name__ == "__main__":  # required on macOS/Windows, where workers are spawned
    # start pool processes: if the node has 10 CPU cores, start 10 processes
    pool = multiprocessing.Pool(processes=10)
    # use pool.map to evaluate function(input) for each input in parallel;
    # len(inputs) is very large and 10 inputs are processed in parallel at a time
    # store the results in a list called out
    out = pool.map(function, inputs)
    # close pool processes to free memory
    pool.close()
    pool.join()
```
I know that JAX has `vmap` and `pmap`, but I don't understand whether either of them is a drop-in replacement for how I'm using `multiprocessing.pool.map` above.
- Is `vmap(function, in_axes=0)(inputs)` distributing to all available CPU cores or what?
No, `vmap` has nothing to do with parallelization. It is a vectorizing transformation, not a parallelizing one. In the course of normal operation, JAX may use multiple cores via XLA, so vmapped operations may do so as well, but there is no explicit parallelization in `vmap`.
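A minimal sketch of what `vmap` does, assuming a hypothetical JAX-traceable `function` that maps one scalar input to one scalar output. The batched call gives the same values as a sequential Python loop, but it is a single vectorized computation on one device rather than a set of parallel tasks. Unlike `pool.map`, this only works if `function` is written with operations JAX can trace (e.g. `jax.numpy`).

```python
import jax
import jax.numpy as jnp
import numpy as np


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


inputs = jnp.linspace(0.0, 1.0, 100)

# Sequential Python-loop equivalent of pool.map(function, inputs)
out_loop = np.array([float(function(x)) for x in inputs])

# vmap traces `function` once and maps it over axis 0 of `inputs`,
# producing one vectorized computation on a single device
out_vmap = jax.vmap(function, in_axes=0)(inputs)

print(out_vmap.shape)  # (100,)
```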
- How is `pmap(function, in_axes=0)(inputs)` different from `vmap` and `multiprocessing.pool.map`?
`pmap` parallelizes over multiple XLA devices. `vmap` does not parallelize; it vectorizes on a single device. `multiprocessing` parallelizes over multiple Python processes.
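To make the `pmap`/`vmap` difference concrete, here is a small sketch using the same hypothetical `function`. `pmap` sends each slice along the leading axis to a different XLA device, so that axis must not be longer than `jax.local_device_count()`; on a default CPU-only install that count is 1, so the sketch only uses as many rows as there are devices. `vmap` has no such limit.

```python
import jax
import jax.numpy as jnp


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


n_devices = jax.local_device_count()  # 1 on a default CPU-only setup

# pmap: one row per device, so the leading axis must be <= n_devices
inputs = jnp.arange(n_devices, dtype=jnp.float32)
out_pmap = jax.pmap(function, in_axes=0)(inputs)

# vmap: any leading-axis length, vectorized on a single device
out_vmap = jax.vmap(function, in_axes=0)(jnp.arange(100, dtype=jnp.float32))

print(n_devices, out_pmap.shape, out_vmap.shape)
```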
- Is my usage of `multiprocessing.pool.map` above an example of the kind of "single-program, multiple-data" (SPMD) code that `pmap` is meant for?
Yes, it could be described as SPMD across multiple Python processes.
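In SPMD style, every device runs the same program on its own slice of the data, and the copies can communicate through collectives. A minimal sketch, assuming the hypothetical `function` below and using `jax.lax.psum` over a named axis (the name `"devices"` is arbitrary): each device sums the results for its chunk, and `psum` adds those partial sums across devices.

```python
import jax
import jax.numpy as jnp


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


def per_device(chunk):
    # Same program on every device; `chunk` is this device's slice of the data
    partial = jnp.sum(jax.vmap(function)(chunk))
    # Collective: sum the partial results over all devices on the "devices" axis
    return jax.lax.psum(partial, axis_name="devices")


n_devices = jax.local_device_count()
data = jnp.arange(n_devices * 8, dtype=jnp.float32).reshape(n_devices, 8)

totals = jax.pmap(per_device, axis_name="devices")(data)
print(totals)  # every device holds the same global total
```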
- When I actually do `pmap(function, in_axes=0)(inputs)`, I get an error: `ValueError: compiling computation that requires 10 logical devices, but only 1 XLA devices are available (num_replicas=10, num_partitions=1)`. What does this mean?
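`pmap` parallelizes over multiple XLA devices, mapping each slice along the leading axis of `inputs` to its own device, so a leading axis of size 10 requires 10 devices. You have only a single XLA device configured (by default JAX exposes the CPU as one device), so the requested operation is not possible.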
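One common way to get more than one XLA device on a CPU-only machine is to ask XLA to split the host into several logical devices with the `--xla_force_host_platform_device_count` flag, which has to be set before JAX is imported (newer JAX releases also have a `jax_num_cpu_devices` config option for this; check whether your version supports it). The sketch below assumes 10 logical devices to match the 10-core example and reuses the hypothetical `function`. These are logical devices, so how much real speedup you get depends on how XLA schedules them across the physical cores.

```python
import os

# Must be set before JAX initializes its CPU backend (i.e. before `import jax`);
# appends to any XLA_FLAGS already present in the environment
os.environ["XLA_FLAGS"] = (
    os.environ.get("XLA_FLAGS", "") + " --xla_force_host_platform_device_count=10"
).strip()

import jax
import jax.numpy as jnp


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


cpu_devices = jax.devices("cpu")
print(len(cpu_devices))  # number of logical CPU devices XLA created

# With 10 CPU devices, a leading axis of size 10 no longer raises the error
inputs = jnp.arange(10, dtype=jnp.float32)
out = jax.pmap(function, devices=cpu_devices)(inputs)
print(out.shape)
```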
- Finally, my use case is very simple: I merely want to use some or all of the CPU cores on a single node (e.g., all 10 CPU cores on my MacBook). But I have heard about nesting `pmap(vmap)`. Is this used to parallelize over the cores of multiple connected nodes (say, on a supercomputer)? That would be more akin to `mpi4py` than to `multiprocessing` (the latter is restricted to a single node).
Yes, I believe that `pmap` can be used to compute on multiple CPU cores. Whether or not it's nested with `vmap` is irrelevant to that. See JAX pmap with multi-core CPU.
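For the original use case (a long `inputs` array and a fixed number of cores), one way to combine the two transformations is to reshape the inputs into `(n_devices, chunk)` and apply `pmap(vmap(function))`: `pmap` spreads the chunks over devices, and `vmap` handles the many inputs within each chunk. This is a sketch with a hypothetical `function` and a hypothetical helper `parallel_map`; it pads the inputs so their length divides evenly, and it only runs in parallel if several CPU devices have been configured (see the linked question).

```python
import jax
import jax.numpy as jnp


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


def parallel_map(fn, xs):
    """Hypothetical helper: evaluate fn on every row of xs across local devices."""
    n_devices = jax.local_device_count()
    n = xs.shape[0]
    chunk = -(-n // n_devices)  # ceiling division
    pad = chunk * n_devices - n
    # Pad by repeating the last row so the length divides evenly
    padded = jnp.pad(xs, [(0, pad)] + [(0, 0)] * (xs.ndim - 1), mode="edge")
    # pmap over devices (axis 0), vmap over inputs within each chunk (axis 1)
    out = jax.pmap(jax.vmap(fn))(padded.reshape(n_devices, chunk, *xs.shape[1:]))
    # Flatten back to one result per input and drop the padded results
    return out.reshape(n_devices * chunk, *out.shape[2:])[:n]


inputs = jnp.linspace(0.0, 1.0, 1003)
out = parallel_map(function, inputs)
print(out.shape)  # (1003,)
```

Padding keeps every device's chunk the same shape, which `pmap` requires; the extra results are simply discarded at the end.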
Note also that `jax.pmap` is deprecated in favor of the newer `jax.shard_map`, which is a much more flexible transformation for multi-device/multi-host computation. There's some info here: https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization.html and https://jax.readthedocs.io/en/latest/jep/14273-shard-map.html
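A rough `shard_map` counterpart of the `pmap(vmap(...))` pattern, offered as a sketch under some assumptions: the import location of `shard_map` has moved between JAX releases (it is tried from `jax` first, then from `jax.experimental.shard_map`), the mesh axis name `"x"` is arbitrary, `function` and `parallel_map` are hypothetical, and the input length is assumed to divide evenly by the number of devices. Each device receives a contiguous block of `inputs` and applies `vmap(function)` to it.

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, PartitionSpec as P

try:
    from jax import shard_map  # newer JAX releases
except ImportError:
    from jax.experimental.shard_map import shard_map  # older releases


def function(x):
    # hypothetical per-input function standing in for the user's own
    return jnp.sin(x) * x + x ** 2


# 1-D mesh over all available devices, with an (arbitrary) axis name "x"
mesh = Mesh(np.array(jax.devices()), axis_names=("x",))


@jax.jit
def parallel_map(xs):
    # Each device sees its own block of xs (split along axis 0 over "x")
    # and applies the vectorized function to that block locally
    return shard_map(
        jax.vmap(function), mesh=mesh, in_specs=(P("x"),), out_specs=P("x")
    )(xs)


inputs = jnp.linspace(0.0, 1.0, 8 * len(jax.devices()))
out = parallel_map(inputs)
print(out.shape)
```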