Tags: python, memory, cpu, tpu, jax

Execute a function specifically on CPU in JAX


I have a function that will instantiate a huge array and do other things. I am running my code on TPUs so my memory is limited.

How can I execute my function specifically on the CPU?

If I do:

y = jax.device_put(my_function(), device=jax.devices("cpu")[0])

I guess that my_function() is first executed on the TPU and only the result is then moved to the CPU, which gives me a memory error.

Using jax.config.update('jax_platform_name', 'cpu') at the beginning of my code also seems to have no effect.

Also, please note that I can't modify my_function().

Thanks!


Solution

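  • To directly specify the device on which a function should be executed, you can use the device argument of jax.jit. Note that recent JAX releases deprecate this argument, so check the documentation for your version. For example (using a GPU runtime because it's the accelerator I have access to at the moment):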

    import jax
    
    gpu_device = jax.devices('gpu')[0]
    cpu_device = jax.devices('cpu')[0]
    
    def my_function(x):
      return x.sum()
    
    x = jax.numpy.arange(10)
    
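    # Newer JAX releases replace the Array.device() method with Array.devices(),
    # which returns the set of devices that hold the array.
    x_gpu = jax.jit(my_function, device=gpu_device)(x)
    print(x_gpu.devices())
    # prints a set containing the GPU device (gpu:0)
    
    x_cpu = jax.jit(my_function, device=cpu_device)(x)
    print(x_cpu.devices())
    # prints a set containing the CPU device (TFRT_CPU_0)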
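
    If the device argument of jax.jit is unavailable or warns in your JAX version, one alternative is to commit the input to the target device with jax.device_put, since a jitted computation runs where its committed inputs live. The following is only a minimal sketch under that assumption. It targets the CPU so that it runs on any machine, reuses the toy my_function from above, and builds the input with NumPy so that nothing is allocated on the accelerator first:

```python
import jax
import numpy as np

cpu_device = jax.devices("cpu")[0]

def my_function(x):
    return x.sum()

# Build the input on the host with NumPy, then commit it to the CPU device.
x_on_cpu = jax.device_put(np.arange(10), cpu_device)

# No device argument here: the computation follows the committed input.
result = jax.jit(my_function)(x_on_cpu)

# devices() returns the set of devices that hold the result.
print(result.devices())
```

    This only helps when the function takes inputs whose placement you control. For a function that creates its own arrays, a default-device setting is the better fit.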
    

    This can also be controlled with the jax.default_device context manager around the call site, which works without modifying the function itself:

    with jax.default_device(cpu_device):
      print(jax.jit(my_function)(x).devices())
      # prints a set containing the CPU device (TFRT_CPU_0)
    
    with jax.default_device(gpu_device):
      print(jax.jit(my_function)(x).devices())
      # prints a set containing the GPU device (gpu:0)
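
Applied to the situation in the question, where my_function takes no arguments, allocates its own array and cannot be modified, the context-manager approach might look like the sketch below. The body of my_function here is a hypothetical stand-in (the real function is not shown in the question), and the array size is kept small so the example runs quickly:

```python
import jax
import jax.numpy as jnp

cpu_device = jax.devices("cpu")[0]

def my_function():
    # Hypothetical stand-in for the unmodifiable function in the question:
    # it allocates an array internally and does some work on it.
    big = jnp.ones((1000, 1000))
    return big @ big

# Arrays created inside this block are placed on the CPU by default,
# so the allocation inside my_function() does not touch accelerator memory.
with jax.default_device(cpu_device):
    y = my_function()

# devices() returns the set of devices that hold the result.
print(y.devices())
```

If the result is needed on the accelerator afterwards, it can be moved explicitly with jax.device_put once it is small enough to fit.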