I am using JAX, and I want to perform an operation like

```python
@jax.jit
def fun(x, index):
    x[:index] = other_fun(x[:index])
    return x
```
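This cannot be performed under `jit`: JAX arrays are immutable, and when `index` is a traced value, `x[:index]` does not have a static shape. Is there a way of doing this with `jax.ops` or `jax.lax`?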
I thought of using `jax.ops.index_update(x, idx, y)` (since removed from JAX in favor of `x.at[idx].set(y)`), but I cannot find a way of computing `y` without running into the same problem again.
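For comparison, if `index` can be treated as a compile-time constant, the slice update itself works under `jit` with the `.at` syntax. The sketch below assumes `index` is passed as a static argument via `static_argnames` (so it must be a hashable Python `int`, and JAX recompiles for every new value) and uses a hypothetical `other_fun` that adds 1:

```python
from functools import partial

import jax
import jax.numpy as jnp


def other_fun(x):
    # Hypothetical stand-in for the question's other_fun.
    return x + 1


@partial(jax.jit, static_argnames=("index",))
def fun(x, index):
    # index is a plain Python int at trace time, so x[:index] has a static shape
    # and .at[...].set returns an updated copy instead of mutating x.
    return x.at[:index].set(other_fun(x[:index]))


x = jnp.arange(5)
print(fun(x, 3))  # [1 2 3 3 4]
```

This avoids the shape problem only because the slice length is fixed per compilation; a traced `index` still needs a different approach.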
The previous answer by @rvinas using `dynamic_slice` works well if your index is static, but you can also accomplish this with a dynamic index using `jnp.where`. For example:
```python
import jax
import jax.numpy as jnp

def other_fun(x):
    return x + 1

@jax.jit
def fun(x, index):
    # True for positions < index; the mask has a static shape even if index is traced
    mask = jnp.arange(x.shape[0]) < index
    return jnp.where(mask, other_fun(x), x)

x = jnp.arange(5)
print(fun(x, 3))
# [1 2 3 3 4]
```
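Because `index` is an ordinary traced argument here, the same compiled function handles any index value, and it can even be batched over several indices with `jax.vmap`. A minimal sketch, reusing the `x + 1` example and assuming we want one output row per index:

```python
import jax
import jax.numpy as jnp


def other_fun(x):
    return x + 1


@jax.jit
def fun(x, index):
    # Boolean mask of static shape; only its values depend on the traced index.
    mask = jnp.arange(x.shape[0]) < index
    # other_fun runs on the whole array; the mask then picks which results to keep.
    return jnp.where(mask, other_fun(x), x)


x = jnp.arange(5)
indices = jnp.array([0, 1, 3, 5])
# in_axes=(None, 0): share x across the batch, map over the indices.
batched = jax.vmap(fun, in_axes=(None, 0))(x, indices)
print(batched)
# [[0 1 2 3 4]
#  [1 1 2 3 4]
#  [1 2 3 3 4]
#  [1 2 3 4 5]]
```

Since `other_fun` is applied to every element and the mask only selects afterwards, this matches slicing only when `other_fun` acts elementwise (as `x + 1` does). A function whose output at one position depends on other positions, such as a sum or a sort, would also see the elements at or past `index`.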