I found a piece of JAX code from a few years ago:
```python
import jax
import jax.random as rand

device_cpu = None


def do_on_cpu(f):
    global device_cpu
    if device_cpu is None:
        device_cpu = jax.devices('cpu')[0]
    def inner(*args, **kwargs):
        with jax.default_device(device_cpu):
            return f(*args, **kwargs)
    return inner


seed2key = do_on_cpu(rand.PRNGKey)
seed2key.__doc__ = '''Same as `jax.random.PRNGKey`, but always produces the result on CPU.'''
```
Then I call it with:

```python
key = seed2key(42)
```
But it results in a `TypeError`:

```
TypeError                                 Traceback (most recent call last)
Cell In[2], line 14
---> 14 key = seed2key(42)

File ~/bert-tokenizer-cantonese/lib/seed2key.py:12, in do_on_cpu.<locals>.inner(*args, **kwargs)
     11 def inner(*args, **kwargs):
---> 12     with jax.default_device(device_cpu):
     13         return f(*args, **kwargs)

TypeError: 'Device' object is not callable
```
I think the function has had breaking changes after a version upgrade.

Current versions: (latest version at the moment of writing)

How can I change the code to avoid the error? Thanks.
This code works fine in all recent versions of JAX: `jax.default_device` is a configuration function designed to be used as a context manager.
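For reference, here is a minimal sketch of the intended usage, assuming a JAX version recent enough to provide both `jax.default_device` and `jax.Array.devices()` (0.4.x or later). The names `cpu`, `key` and `x` are only illustrative:

```python
import jax
import jax.numpy as jnp

cpu = jax.devices("cpu")[0]

# jax.default_device(...) returns a context manager; arrays created inside
# the `with` block are placed on `cpu` unless a device is given explicitly.
with jax.default_device(cpu):
    key = jax.random.PRNGKey(42)
    x = jnp.zeros(3)

# Both arrays report the CPU device they live on.
print(key.devices(), x.devices())
```

This is exactly what `inner` in the question does around each call to `f`, so the decorator itself needs no changes.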
I can reproduce the error you're seeing if I add this to the top of your script:
```python
jax.default_device = jax.devices('cpu')[0]  # wrong!
```
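A self-contained version of that reproduction, which also puts the original object back afterwards so the session keeps working. This is only an illustration: `original`, `cpu` and `message` are hypothetical names, and the exact class name in the error text (`Device` vs. `CpuDevice`) may vary between jaxlib versions:

```python
import jax

original = jax.default_device  # keep a reference to the real context manager
cpu = jax.devices("cpu")[0]
message = ""

jax.default_device = cpu  # the accidental overwrite
try:
    with jax.default_device(cpu):  # now calls a Device object -> TypeError
        pass
except TypeError as e:
    message = str(e)
finally:
    jax.default_device = original  # undo the overwrite

print(message)
```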
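I suspect you inadvertently executed something similar to this at some point earlier in your notebook session. Such an assignment replaces the `default_device` attribute of the `jax` module for the rest of the session, so every later call to `jax.default_device(...)`, including the one inside `inner`, tries to call a `Device` object. Try restarting your notebook runtime and rerunning only your valid code.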
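If you want to check the current session before restarting, here is a quick sketch. `default_device_is_shadowed` is a hypothetical helper, not part of JAX; it simply tests whether the attribute is still callable, since a `Device` object is not:

```python
import jax


def default_device_is_shadowed() -> bool:
    """Return True if jax.default_device is no longer callable,
    e.g. because it was overwritten with a Device object."""
    return not callable(jax.default_device)


print(default_device_is_shadowed())
```

If it prints `True`, something in the session has overwritten the attribute, and restarting the runtime is the simplest fix.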