
Conditional update in JAX?


In autograd/numpy I could do:

q[q<0] = 0.0

How can I do the same thing in JAX?

I tried importing the original NumPy (import numpy as onp) and using it to create the arrays, but that doesn't seem to work.


Solution

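  • JAX arrays are immutable, so in-place index assignment statements cannot work. Instead, JAX arrays provide the .at property, whose methods (.set, .add, .multiply, .min, .max and others) return an updated copy of the array. (Older JAX versions exposed this functionality through the jax.ops submodule, e.g. jax.ops.index_update, which has since been deprecated and removed.)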

    Here is an example of a numpy index assignment and the equivalent JAX index update:

    import numpy as np
    q = np.arange(-5, 5)
    q[q < 0] = 0
    print(q)
    # [0 0 0 0 0 0 1 2 3 4]
    
    import jax.numpy as jnp
    q = jnp.arange(-5, 5)
    q = q.at[q < 0].set(0)  # NB: this does not modify the original array,
                            # but rather returns a modified copy.
    print(q)
    # [0 0 0 0 0 0 1 2 3 4]
    

    Note that in op-by-op mode (outside of jax.jit), each such update creates a new copy of the array. Within a JIT-compiled function, however, XLA can often perform the update in place or fuse it with surrounding operations, avoiding the copy.
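To illustrate the immutability point, here is a small sketch (the variable names are illustrative only): the NumPy-style assignment raises a TypeError on a JAX array, .at[...].set(...) leaves the original array untouched, and the same .at property offers other update methods such as .add.

```python
import jax.numpy as jnp

q = jnp.arange(-5, 5)

# NumPy-style in-place assignment is rejected: JAX arrays are immutable.
assignment_failed = False
try:
    q[q < 0] = 0
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# .at[...].set(...) returns a new array and leaves q unchanged.
q_clipped = q.at[q < 0].set(0)
print(q)          # still contains the negative values
print(q_clipped)  # negative entries replaced by 0

# Other methods on .at follow the same pattern, e.g. .add:
q_shifted = q.at[q < 0].add(10)  # adds 10 only where q < 0
print(q_shifted)
```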
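When the update is done inside jax.jit, keep in mind that a boolean mask like q < 0 selects a data-dependent number of elements, which does not fit JIT's requirement of static shapes; depending on the JAX version, q.at[q < 0].set(0) inside a jitted function may raise NonConcreteBooleanIndexError. A common JIT-friendly way to express the same conditional update is the three-argument jnp.where, as in this sketch (clip_negative is a hypothetical helper name):

```python
import jax
import jax.numpy as jnp


@jax.jit
def clip_negative(q):
    # jnp.where(cond, x, y) picks 0 where q < 0 and keeps q elsewhere.
    # The output has the same shape as the input no matter how many
    # entries are negative, so it works with jit's static shapes.
    return jnp.where(q < 0, 0, q)


q = jnp.arange(-5, 5)
result = clip_negative(q)
print(result)
```

Because the result is a fresh array of the same shape, XLA is free to compute it without an intermediate copy of q.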