What would be a workaround for this code inside a jitted function?
j = indices.index(list(neighbor))
where neighbor is, for example, (2, 3), and indices = [[1, 2], [4, 5], ...]
Other alternatives like partial didn't work. One issue with partial is that indices is not hashable, so it can't be used here.
list.index is a Python function that will not work within JIT if the contents of the list are traced values. I would recommend converting your lists to arrays and doing something like this:
import jax
import jax.numpy as jnp
indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])
neighbor = jnp.array([2, 3])
@jax.jit
def get_index(indices, neighbor):
    # Compare every row to neighbor; size=1 keeps the output shape static under JIT.
    return jnp.where((indices == neighbor).all(-1), size=1)[0]
idx = get_index(indices, neighbor)
print(idx)
# [3]
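
One caveat with the approach above: when size is given and no row matches, jnp.where pads the result with 0 by default, which looks the same as a match at row 0. A minimal sketch of one way to tell the two apart, assuming -1 is an acceptable "not found" marker (find_row is a hypothetical name, not part of JAX):

```python
import jax
import jax.numpy as jnp

@jax.jit
def find_row(indices, neighbor):
    # Hypothetical helper: index of the first row equal to neighbor, or -1.
    matches = (indices == neighbor).all(-1)
    # fill_value replaces the default padding of 0 when nothing matches.
    return jnp.where(matches, size=1, fill_value=-1)[0][0]

indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])
found = find_row(indices, jnp.array([2, 3]))
missing = find_row(indices, jnp.array([9, 9]))
print(int(found), int(missing))
# 3 -1
```

The result is still a traced value inside JIT, so any branching on it would need jnp.where or jax.lax.cond rather than a Python if.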