
JAX TypeError: 'Device' object is not callable


I found a piece of JAX code from a few years ago:

import jax
import jax.random as rand

device_cpu = None

def do_on_cpu(f):
    global device_cpu
    if device_cpu is None:
        device_cpu = jax.devices('cpu')[0]

    def inner(*args, **kwargs):
        with jax.default_device(device_cpu):
            return f(*args, **kwargs)
    return inner

seed2key = do_on_cpu(rand.PRNGKey)
seed2key.__doc__ = '''Same as `jax.random.PRNGKey`, but always produces the result on CPU.'''
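A minimal self-contained sketch of the same helper, assuming a recent JAX release where jax.default_device(...) is a context manager and jax.Array has a .devices() method. It swaps the original plain closure for functools.wraps so the wrapper keeps PRNGKey's metadata, and it looks up the CPU device at call time instead of caching it in a global. The final line checks where the key actually ended up.

```python
import functools

import jax
import jax.random as rand


def do_on_cpu(f):
    """Wrap f so each call runs with the first CPU device as JAX's default device."""

    @functools.wraps(f)  # copy f's __name__, __doc__, etc. onto the wrapper
    def inner(*args, **kwargs):
        # Look up the CPU device at call time instead of caching it in a global.
        cpu = jax.devices("cpu")[0]
        # jax.default_device must be *called* to get a context manager.
        with jax.default_device(cpu):
            return f(*args, **kwargs)

    return inner


seed2key = do_on_cpu(rand.PRNGKey)
seed2key.__doc__ = "Same as `jax.random.PRNGKey`, but always produces the result on CPU."

key = seed2key(42)
# jax.Array.devices() returns the set of devices holding the array.
print(key.devices())
```

On a machine with a GPU, the printed set should still contain only the CPU device, which is the whole point of the wrapper.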

Then I call it with:

key = seed2key(42)

But it raises a TypeError:

TypeError                                 Traceback (most recent call last)
Cell In[2], line 14
---> 14 key = seed2key(42)

File ~/bert-tokenizer-cantonese/lib/seed2key.py:12, in do_on_cpu.<locals>.inner(*args, **kwargs)
     11 def inner(*args, **kwargs):
---> 12     with jax.default_device(device_cpu):
     13         return f(*args, **kwargs)

TypeError: 'Device' object is not callable

I suspect jax.default_device introduced a breaking change in a newer version.

Current versions: the latest at the time of writing.

How can I change the code to avoid this error? Thanks.


Solution

  • This code works fine in all recent versions of JAX: jax.default_device is a configuration function designed to be used as a context manager.

    I can reproduce the error you're seeing if I add this to the top of your script:

    jax.default_device = jax.devices('cpu')[0]  # wrong!
    

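    That assignment rebinds the name jax.default_device to a Device object, so the later call jax.default_device(device_cpu) tries to call the Device itself and raises exactly this TypeError. I suspect you inadvertently executed something similar at some point earlier in your notebook session. Try restarting your notebook runtime and rerunning only your valid code.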
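The failure mode can be reproduced without JAX at all. The sketch below uses hypothetical stand-ins (a Device class, a default_device context manager, and a fake_jax namespace playing the role of the jax module) to show how rebinding a module attribute to a non-callable object turns the later call into this exact TypeError.

```python
import contextlib
import types


class Device:
    """Hypothetical stand-in for a JAX device object (instances are not callable)."""

    def __repr__(self):
        return "CpuDevice(id=0)"


@contextlib.contextmanager
def default_device(device):
    """Hypothetical stand-in for the real jax.default_device context manager."""
    yield device


# A namespace standing in for the jax module.
fake_jax = types.SimpleNamespace(default_device=default_device)
cpu = Device()

# Works: default_device is still a callable that returns a context manager.
with fake_jax.default_device(cpu) as d:
    print("inside context with", d)

# The accidental rebinding, mirroring the "# wrong!" line above.
fake_jax.default_device = cpu

try:
    with fake_jax.default_device(cpu):  # now calls the Device instance itself
        pass
except TypeError as exc:
    message = str(exc)
    print(message)  # 'Device' object is not callable

# The same check works on the real module in a notebook: callable(jax.default_device)
print(callable(fake_jax.default_device))  # False after the rebinding
```

In a real session, callable(jax.default_device) should return True; if it returns False, the name has been overwritten and restarting the runtime is the simplest way to get the original function back.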