In autograd/numpy I could do:
q[q<0] = 0.0
How can I do the same thing in JAX?
I tried import numpy as onp and using that to create arrays, but that doesn't seem to work.
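JAX arrays are immutable, so in-place index assignment statements cannot work. Instead, JAX arrays provide an indexed update syntax through the .at property (e.g. .at[idx].set(), .at[idx].add(), .at[idx].mul()), which creates updated copies of arrays. Older JAX versions exposed this through functions in the jax.ops submodule such as jax.ops.index_update, but those have since been deprecated and removed in favor of .at.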
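As a small sketch of this copy-on-update behavior (the array x and the update values are illustrative only), the snippet below shows that direct item assignment raises a TypeError, while each .at[...] call returns a new array and leaves its input untouched:

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Direct item assignment is rejected because JAX arrays are immutable.
try:
    x[1] = 10.0
except TypeError as err:
    print("TypeError:", err)

# .at[...].set() returns an updated copy; x itself is left unchanged.
y = x.at[1].set(10.0)

# Other update methods follow the same pattern, e.g. .add and .mul,
# and accept slices as well as integer indices.
z = y.at[1:3].add(1.0)
w = z.at[1].mul(2.0)

print(x)
print(y)
print(z)
print(w)
```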
Here is an example of a numpy index assignment and the equivalent JAX index update:
import numpy as np
q = np.arange(-5, 5)
q[q < 0] = 0
print(q)
# [0 0 0 0 0 0 1 2 3 4]
import jax.numpy as jnp
q = jnp.arange(-5, 5)
q = q.at[q < 0].set(0) # NB: this does not modify the original array,
# but rather returns a modified copy.
print(q)
# [0 0 0 0 0 0 1 2 3 4]
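For this particular pattern (replacing values that match a condition), an alternative that avoids boolean-mask indexing altogether is jnp.where, which selects elementwise between two values. Because the output shape never depends on the data, it also composes naturally with jax.jit. For the specific "set negatives to zero" case, jnp.maximum gives the same result. A minimal sketch:

```python
import jax.numpy as jnp

q = jnp.arange(-5, 5)

# jnp.where(cond, a, b) takes a where cond is True and b elsewhere;
# the result always has the same shape as q.
q_where = jnp.where(q < 0, 0, q)

# Equivalent for this case: clamp every element from below at 0.
q_max = jnp.maximum(q, 0)

print(q_where)
print(q_max)
```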
Note that in op-by-op (non-JIT) mode, the JAX version does create a new copy of the array for each update. However, when used within a jax.jit-compiled function, XLA can often fuse such operations and perform the updates in place, avoiding copying of the data.
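As a sketch of the JIT case, the hypothetical function update_ends below chains several .at updates inside jax.jit. Semantically each step still produces a new value, but the compiler is free to apply them to a single buffer; whether it actually does so is an XLA decision, not something the code guarantees. The caller's input array remains unchanged either way.

```python
import jax
import jax.numpy as jnp


@jax.jit
def update_ends(x):
    # Each .at[...] call is conceptually a copy, but once compiled,
    # XLA may perform these updates in place on one buffer.
    x = x.at[0].set(-1.0)      # overwrite the first element
    x = x.at[-1].add(10.0)     # negative indices work as in NumPy
    x = x.at[1:3].mul(2.0)     # slices are supported too
    return x


x = jnp.arange(5.0)
out = update_ends(x)
print(out)
print(x)  # the original array is not modified
```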