jax

How to get the index position of a value with jit?


What would be a workaround for this code in a jitted function?

j = indices.index(list(neighbor)), where neighbor is, for example, (2, 3) and indices = [[1, 2], [4, 5], ...]

Other alternatives, like partial, didn't work. One issue with partial is that indices is not hashable, so I can't use it there.


Solution

  • list.index is a Python method that will not work within JIT if the contents of the list are traced values. I would recommend converting your lists to arrays and doing something like this:

    import jax
    import jax.numpy as jnp
    
    indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])
    neighbor = jnp.array([2, 3])
    
    @jax.jit
    def get_index(indices, neighbor):
      # Compare each row with neighbor; size=1 gives the output a static shape, which jit needs.
      return jnp.where((indices == neighbor).all(-1), size=1)[0]
    
    idx = get_index(indices, neighbor)
    print(idx)
    # [3]
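
One caveat with the snippet above: when size is given and no row matches, jnp.where pads the result with fill_value, which defaults to 0, so a miss looks the same as a match at index 0. A minimal sketch of one way around this, assuming -1 is an acceptable "not found" marker (the name get_index_or_minus_one is hypothetical, chosen only for illustration):

```python
import jax
import jax.numpy as jnp

@jax.jit
def get_index_or_minus_one(indices, neighbor):
    # True for each row whose elements all equal neighbor.
    matches = (indices == neighbor).all(axis=-1)
    # size=1 keeps the output shape static; fill_value=-1 marks "not found"
    # (the -1 sentinel is an assumption of this sketch).
    (idx,) = jnp.where(matches, size=1, fill_value=-1)
    return idx[0]

indices = jnp.array([[1, 2], [4, 5], [3, 6], [2, 3], [5, 7]])
found = get_index_or_minus_one(indices, jnp.array([2, 3]))
missing = get_index_or_minus_one(indices, jnp.array([9, 9]))
print(found, missing)
```

Here found should be 3 and missing should be -1. Python-level checks such as `if idx == -1:` belong outside the jitted function, since the value is traced inside it.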