Here is a very simple computation in JAX that fails with a complaint about static indices:
import jax
import jax.numpy as jnp
def get_slice(ar, k, i):
    # Fails under vmap: i is mapped over, so the slice bounds are traced values
    return ar[i:i+k]
vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
arr = jnp.array([1, 2, 3, 4, 5])
vec_get_slice(arr, 2, jnp.arange(3))
---------------------------------------------------------------------------
IndexError Traceback (most recent call last)
<ipython-input-32-6c60650ce6b7> in <cell line: 1>()
----> 1 vec_get_slice(arr, 2, jnp.arange(3))
[... skipping hidden 3 frame]
<ipython-input-29-9528369725c2> in get_slice(ar, k, i)
1 def get_slice(ar, k, i):
----> 2 return ar[i:i+k]
/usr/local/lib/python3.10/dist-packages/jax/_src/array.py in __getitem__(self, idx)
346 return out
347
--> 348 return lax_numpy._rewriting_take(self, idx)
349
350 def __iter__(self):
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
4602
4603 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
-> 4604 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
4605 unique_indices, mode, fill_value)
4606
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted, unique_indices, mode, fill_value)
4611 unique_indices, mode, fill_value):
4612 idx = _merge_static_and_dynamic_indices(treedef, static_idx, dynamic_idx)
-> 4613 indexer = _index_to_gather(shape(arr), idx) # shared with _scatter_update
4614 y = arr
4615
/usr/local/lib/python3.10/dist-packages/jax/_src/numpy/lax_numpy.py in _index_to_gather(x_shape, idx, normalize_indices)
4854 "dynamic_update_slice (JAX does not support dynamically sized "
4855 "arrays within JIT compiled functions).")
-> 4856 raise IndexError(msg)
4857
4858 start, step, slice_size = _preprocess_slice(i, x_shape[x_axis])
Horrible error output below. I am obviously missing something simple, but what?
IndexError: Array slice indices must have static start/stop/step to be used with NumPy indexing syntax. Found slice(Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([0, 1, 2], dtype=int32)
batch_dim = 0, Traced<ShapedArray(int32[])>with<BatchTrace(level=1/0)> with
val = Array([2, 3, 4], dtype=int32)
batch_dim = 0, None). To index a statically sized array at a dynamic position, try lax.dynamic_slice/dynamic_update_slice (JAX does not support dynamically sized arrays within JIT compiled functions).
Indices passed to slices in JAX must be static. Values that are mapped over in vmap are not static: because you're mapping over the start indices, they are not static, and you see this error.
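To make the static/traced distinction concrete, here is a minimal sketch that records what the function body actually receives under vmap. The helper name `inspect_args` and the `seen` dictionary are hypothetical, introduced only for illustration. It assumes, as in the question, that the unmapped argument is passed as a plain Python int.

```python
import jax
import jax.numpy as jnp

seen = {}

def inspect_args(k, i):
    # Hypothetical helper, for illustration only.
    # k uses in_axes=None, so vmap hands it through unchanged (a Python int).
    seen["k_is_python_int"] = isinstance(k, int)
    # i is mapped over, so inside vmap it is a tracer, not a concrete int.
    seen["i_is_python_int"] = isinstance(i, int)
    return i + k

out = jax.vmap(inspect_args, in_axes=(None, 0))(2, jnp.arange(3))
print(seen)
print(out)
```

Because `i` has no concrete value while the function is being traced, it cannot be used as a slice bound in NumPy-style indexing, whereas `k` can be used anywhere a static size is needed.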
There is good news though: the size of your subarray is controlled by k, which is unmapped in your code and therefore static; it's only the location of the slice (given by i) that is dynamic. This is exactly the situation that jax.lax.dynamic_slice was designed for, so you can rewrite your code like this:
import jax
import jax.numpy as jnp
def get_slice(ar, k, i):
    # The start index i may be traced; the slice size k must be static
    return jax.lax.dynamic_slice(ar, (i,), (k,))
vec_get_slice = jax.vmap(get_slice, in_axes=(None, None, 0))
arr = jnp.array([1, 2, 3, 4, 5])
vec_get_slice(arr, 2, jnp.arange(3))
# Array([[1, 2],
#        [2, 3],
#        [3, 4]], dtype=int32)
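As a possible alternative, and not something the answer above proposes, the same windows can be built with integer-array indexing instead of a slice. This is a minimal sketch that relies on k being static, so that jnp.arange(k) has a known shape; the function name `get_slice_gather` is hypothetical.

```python
import jax
import jax.numpy as jnp

def get_slice_gather(ar, k, i):
    # Hypothetical alternative to dynamic_slice, for illustration only.
    # k is static, so jnp.arange(k) has a fixed shape; only the offset i is traced.
    # Indexing with an integer array is allowed even when the indices are traced.
    return ar[i + jnp.arange(k)]

vec_get_slice_gather = jax.vmap(get_slice_gather, in_axes=(None, None, 0))
arr = jnp.array([1, 2, 3, 4, 5])
windows = vec_get_slice_gather(arr, 2, jnp.arange(3))
print(windows)
```

For in-bounds start positions this gives the same windows as the dynamic_slice version. The two approaches may treat out-of-range start positions differently, so it is worth checking the documentation if your indices can run past the end of the array.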