I'm new to JAX and want to work with multiple GPUs. So far, two GPUs (0 and 1) are visible to JAX:
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # set before importing JAX so the CUDA backend sees it
import jax
print(jax.local_devices())
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
When I create a JAX array, it is always placed on GPU device 0, which I assume is the default:
nmp = jax.numpy.ones(4)
print(nmp.device())
# prints: gpu:0
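A minimal sketch for checking where an array lives, assuming a recent JAX release where arrays are `jax.Array` objects. Depending on the JAX version, `.device` may be a property rather than a method, so this sketch uses `.devices()`, which returns the set of devices holding the array's data. It also assumes no custom default device has been configured, in which case JAX places new arrays on `jax.devices()[0]`.

```python
import jax
import jax.numpy as jnp

x = jnp.ones(4)

# .devices() returns a set; an unsharded array lives on exactly one device.
print(x.devices())

# With no default-device override, new arrays go to the first device
# that jax.devices() reports for the default backend.
default = jax.devices()[0]
print(default in x.devices())
```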
How can I move my variable nmp so that it is stored on gpu:1, the other GPU?
Use jax.device_put():
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'  # set before importing JAX
import jax
devices = jax.local_devices()
print(devices)  # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
nmp = jax.numpy.ones(4)
print(nmp.device())  # >>> gpu:0
# Copy the array to the second GPU; device_put returns a new array placed on that device
nmp = jax.device_put(nmp, devices[1])
print(nmp.device())  # >>> gpu:1
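As an alternative to moving an array after it has been created, the `jax.default_device` context manager can be used to create arrays on a chosen device directly. The sketch below shows both approaches side by side. It assumes a recent JAX release (for `.devices()`), and it falls back to the first device when only one is present so that it also runs on a CPU-only machine. The names `target`, `a`, `b`, and `c` are illustrative.

```python
import jax
import jax.numpy as jnp

devices = jax.local_devices()
# Use the second device if available, otherwise the first (e.g. CPU-only runs).
target = devices[1] if len(devices) > 1 else devices[0]

# Option 1: move an existing array to the target device.
a = jax.device_put(jnp.ones(4), target)

# Option 2: create the array on the target device from the start.
with jax.default_device(target):
    b = jnp.ones(4)

# Computations on an array committed via device_put run on that same device.
c = a * 2

print(a.devices(), b.devices(), c.devices())
```

Because `device_put` commits `a` to `target`, follow-up operations on it (such as `c`) stay on that GPU rather than returning to `gpu:0`.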