python numpy jax

Apply function only on slice of array under jit


I am using JAX, and I want to perform an operation like

@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x

This cannot be performed under jit. Is there a way of doing this with jax.ops or jax.lax? I thought of using jax.ops.index_update(x, idx, y), but I cannot find a way of computing y without running into the same problem again.


Solution

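  • The previous answer by @rvinas using dynamic_slice works well if your index is static, but you can also accomplish this with a dynamic index using jnp.where. Note that other_fun is then evaluated on the whole array and a mask selects which results to keep, so this assumes other_fun acts elementwise. For example: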

    import jax
    import jax.numpy as jnp
    
    def other_fun(x):
        return x + 1
    
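    @jax.jit
    def fun(x, index):
        # Boolean mask that is True for the first `index` positions.
        mask = jnp.arange(x.shape[0]) < index
        # other_fun runs on the whole array; the mask picks which results to keep.
        return jnp.where(mask, other_fun(x), x)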
    
    x = jnp.arange(5)
    print(fun(x, 3))
    # [1 2 3 3 4]
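
For comparison, here is a minimal sketch of the static-index case, assuming it is acceptable for index to be a compile-time constant. It is not taken from the earlier answer: it marks index as static with static_argnums and uses the x.at[...].set(...) syntax in place of the in-place assignment from the question. The name fun_static is made up for this illustration.

```python
from functools import partial

import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

# Assumption: `index` is a plain Python int known at trace time,
# so the slice x[:index] has a fixed shape under jit.
@partial(jax.jit, static_argnums=1)
def fun_static(x, index):
    # Functional update: returns a new array and leaves x unchanged.
    return x.at[:index].set(other_fun(x[:index]))

x = jnp.arange(5)
print(fun_static(x, 3))
# [1 2 3 3 4]
```

Here other_fun only sees the slice, so it does not need to be elementwise. The trade-off is that each new value of index triggers a recompilation, which is why the jnp.where version is the better fit when index changes often.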