I have a matrix of indices I, e.g.:
I = np.array([[1, 0, 2], [2, 1, 0]])
Each index in the i-th row of I selects an element from the i-th row of another matrix M.
So, given M, e.g.:
M = np.array([[6, 7, 8], [9, 10, 11]])
the indexing I want (call it M[I]) should select:
[[7, 6, 8], [11, 10, 9]]
I could do this:
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])
I2 = np.ravel(I)
Result = M[I1, I2].reshape(I.shape)
This works, but it looks overly complicated, and I am looking for a more elegant solution, preferably one without flattening and reshaping.
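To see what the flatten-and-reshape version is doing, here are its intermediate index arrays for the example above. I1 repeats each row number once per column, and I2 lists the column indices in the same row-major order, so M[I1, I2] picks the (row, column) pairs (0,1), (0,0), (0,2), (1,2), (1,1), (1,0).

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

# Row number of every element of I, in row-major (flattened) order
I1 = np.repeat(np.arange(0, I.shape[0]), I.shape[1])  # [0 0 0 1 1 1]
# Column indices, flattened in the same order
I2 = np.ravel(I)                                      # [1 0 2 2 1 0]

# Pair the two index arrays element-wise, then restore the 2-D shape
Result = M[I1, I2].reshape(I.shape)
print(Result)
# [[ 7  6  8]
#  [11 10  9]]
```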
I used NumPy in the example, but I am actually using JAX, so if there is a more efficient solution in JAX, feel free to share.
I had to add a ']' to M:

In [108]: I = np.array([[1, 0, 2], [2, 1, 0]])
     ...: M = np.array([[6, 7, 8], [9, 10, 11]])
     ...:
     ...: I,M
Out[108]:
(array([[1, 0, 2],
        [2, 1, 0]]),
 array([[ 6,  7,  8],
        [ 9, 10, 11]]))
Advanced indexing with broadcasting:
In [110]: M[np.arange(2)[:,None],I]
Out[110]:
array([[ 7,  6,  8],
       [11, 10,  9]])
The first index has shape (2,1), which broadcasts against the (2,3) shape of I
to select a (2,3) block of values.
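The broadcasting can be made explicit with np.broadcast_arrays: the (2,1) row index is stretched to the same (2,3) shape as I, giving the row-number matrix that the question's I1 holds in flattened form. NumPy also provides np.take_along_axis, which performs the same per-row gather directly; it is included here as an alternative spelling, checked against the broadcast-indexing result.

```python
import numpy as np

I = np.array([[1, 0, 2], [2, 1, 0]])
M = np.array([[6, 7, 8], [9, 10, 11]])

rows = np.arange(I.shape[0])[:, None]   # shape (2, 1)
R, C = np.broadcast_arrays(rows, I)     # both become shape (2, 3)
print(R)
# [[0 0 0]
#  [1 1 1]]

# Broadcast advanced indexing: out[i, j] = M[R[i, j], C[i, j]]
out = M[rows, I]

# Same gather along axis 1: out2[i, j] = M[i, I[i, j]]
out2 = np.take_along_axis(M, I, axis=1)
print(np.array_equal(out, out2))  # True
```

jax.numpy supports the same NumPy-style advanced indexing on jnp arrays and also provides jnp.take_along_axis, so either form should carry over by replacing np with jnp; this sketch does not compare their performance under JAX.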