
How to specify or set a variable to a GPU device


I'm new to JAX and want to work with multiple GPUs. So far, two GPUs (0 and 1) are visible to JAX:

import os

# Set this before JAX initializes its backend
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'

import jax

print(jax.local_devices())
# prints: [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]

When I create an array with jax.numpy, it is always placed on GPU device 0, which I assume is the default device.

nmp = jax.numpy.ones(4)
print(nmp.device())
# prints: gpu:0

How can I move nmp to gpu:1, the other GPU?


Solution

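  • Use jax.device_put(). It copies an array to the given device and returns a new array placed there; the original array is not modified, so reassign the result: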

    import os
    
    # Must be set before JAX initializes its GPU backend
    os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
    
    import jax
    
    devices = jax.local_devices()
    print(devices) # >>> [GpuDevice(id=0, process_index=0), GpuDevice(id=1, process_index=0)]
    
    nmp = jax.numpy.ones(4)
    print(nmp.device()) # >>> gpu:0
    
    # Copy the array to the second GPU; device_put returns a new array
    nmp = jax.device_put(nmp, devices[1])
    print(nmp.device()) # >>> gpu:1
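The same placement logic can be tried on a machine without two GPUs. The sketch below rests on a few assumptions: it sets JAX_PLATFORMS=cpu and XLA's --xla_force_host_platform_device_count=2 flag so the CPU shows up as two devices, and it checks placement with Array.devices() (which returns a set of devices on jax.Array) rather than .device(), whose form has changed across JAX versions. It also shows jax.default_device, a context manager that creates new arrays directly on a chosen device instead of copying them afterwards. On a real two-GPU machine, drop the two environment lines.

```python
import os

# Assumption for this sketch: simulate two devices on the CPU so it runs
# anywhere. Both variables must be set before JAX initializes its backend.
# On a two-GPU machine, remove these lines and jax.devices() lists the GPUs.
os.environ["JAX_PLATFORMS"] = "cpu"
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

devices = jax.devices()
print(devices)

# With no explicit placement, new arrays go to the first device
x = jnp.ones(4)

# Copy x to the second device; x itself stays on devices[0]
y = jax.device_put(x, devices[1])

# Operations follow their (committed) inputs, so z is computed on devices[1]
z = y * 2

# Alternative: make devices[1] the default for a block of code
with jax.default_device(devices[1]):
    w = jnp.zeros(4)

print(x.devices(), y.devices(), z.devices(), w.devices())
```

Because device_put with an explicit device commits the array to that device, later operations on y run there. Mixing arrays committed to different devices in a single operation raises an error in JAX rather than silently moving data, so put all inputs of a computation on the same device first.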