
Modifying multiple dimensions of a JAX array simultaneously


When using jax_array.at[idx], I want to set the values at a specified set of rows and columns of jax_array to the values of another JAX array of the same shape. For example, given a 5x5 JAX array, I might want to set jax_array.at[[0,3],:][:,[1,2]] (rows 0 and 3, columns 1 and 2) to some 2x2 array of values. However, this fails with an error saying that the '_IndexUpdateRef' object is not subscriptable. I understand the idea behind the error (and I get a similar one when chaining two .at[] calls), but I want to know whether there is any way to achieve this in one line.
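A minimal sketch reproducing the failure, assuming a recent JAX release (the exact error text may vary between versions): x.at[...] does not slice the array. It only records the index and returns an intermediate update helper (an _IndexUpdateRef) that exposes methods such as .get(), .set() and .add(), so indexing that helper a second time raises a TypeError before any update happens.

```python
import jax.numpy as jnp

x = jnp.zeros((5, 5))
rows = jnp.array([0, 3])
cols = jnp.array([1, 2])

# x.at[...] only stores the index and returns an update helper,
# not a sliced array, so it cannot be subscripted again.
ref = x.at[rows, :]
try:
    ref[:, cols]  # the chained indexing attempted in the question
    chained_ok = True
except TypeError as err:
    chained_ok = False
    print("TypeError:", err)
```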


Solution

  • JAX follows NumPy's indexing semantics, and NumPy lets you do this with broadcasted arrays of integer indices (see "Integer array indexing" in the NumPy docs).

    For example, you could do something like this:

    import jax.numpy as jnp
    
    x = jnp.zeros((4, 6), dtype=int)
    y = jnp.array([[1, 2],
                   [3, 4]])
    i = jnp.array([0, 3])
    j = jnp.array([1, 2])
    
    # reshape the indices so they broadcast against each other: i -> (2, 1), j -> (1, 2)
    i = i[:, jnp.newaxis]
    j = j[jnp.newaxis, :]
    
    x = x.at[i, j].set(y)
    print(x)
    # [[0 1 2 0 0 0]
    #  [0 0 0 0 0 0]
    #  [0 0 0 0 0 0]
    #  [0 3 4 0 0 0]]
    

    Here i has shape (2, 1) and j has shape (1, 2), so under NumPy's broadcasting rules they broadcast to a pair of (2, 2) index arrays that select the noncontiguous 2x2 subgrid of x at rows 0, 3 and columns 1, 2. You can then set that subgrid to the contents of y in a single statement.
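As a sketch, assuming a recent JAX release: jnp.ix_ (the JAX counterpart of numpy.ix_) builds the same broadcasted (2, 1) and (1, 2) index arrays for you, so the reshape step can be folded into the index expression. The block below checks that both forms agree, reads the subgrid back with the same index, and wraps the update in set_subgrid, a hypothetical helper (not part of JAX) showing that the broadcasted indices also work under jax.jit.

```python
import jax
import jax.numpy as jnp

x = jnp.zeros((4, 6), dtype=int)
y = jnp.array([[1, 2],
               [3, 4]])
rows = jnp.array([0, 3])
cols = jnp.array([1, 2])

# jnp.ix_ returns index arrays of shape (2, 1) and (1, 2),
# i.e. the same open mesh as the manual reshape.
x_ix = x.at[jnp.ix_(rows, cols)].set(y)

# Manual broadcasting, as in the answer above.
x_manual = x.at[rows[:, None], cols[None, :]].set(y)

# Reading the subgrid back with the same index recovers y.
sub = x_ix[jnp.ix_(rows, cols)]

# Hypothetical helper (not part of JAX): write a 2-D block of values
# at arbitrary row/column indices; the indices may be traced under jit.
@jax.jit
def set_subgrid(arr, r, c, vals):
    return arr.at[r[:, None], c[None, :]].set(vals)

x_jit = set_subgrid(x, rows, cols, y)
print(x_ix)
```

jnp.ix_ is usually the more readable choice when you want every (row, column) combination; the manual reshape is more general because it allows any broadcast pattern. The same index expression also works with the other update methods, such as .add() or .multiply().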