I have a 3D array and a 2D array of indices. How can I use the indices to select one element along the last axis for each position in the first two axes?
import numpy as np
# example array
shape = (4,3,2)
x = np.random.uniform(0,1, shape)
# indices
idx = np.random.randint(0,shape[-1], shape[:-1])
Here is a loop that produces the desired result, but there should be an efficient, vectorized way to do this.
result = np.zeros(shape[:-1])
for i in range(shape[0]):
    for j in range(shape[1]):
        # pick the element of the last axis chosen by idx[i, j]
        result[i, j] = x[i, j, idx[i, j]]
A possible solution:
np.take_along_axis(x, np.expand_dims(idx, axis=-1), axis=-1).squeeze(axis=-1)
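Why the expand_dims/squeeze pair is needed: np.take_along_axis expects an index array with the same number of dimensions as x, so the (4, 3) index array is given a trailing length-1 axis, which makes the result (4, 3, 1); the trailing axis is then dropped. The sketch below is a minimal check of that, using idx[..., None] (equivalent to np.expand_dims(idx, axis=-1)) and a seeded generator so the comparison with the loop is reproducible; the seed is an arbitrary choice.

```python
import numpy as np

shape = (4, 3, 2)
rng = np.random.default_rng(0)  # seeded only to make the check reproducible
x = rng.uniform(0, 1, shape)
idx = rng.integers(0, shape[-1], shape[:-1])

# take_along_axis needs an index array with the same ndim as x,
# so add a trailing length-1 axis: idx[..., None] has shape (4, 3, 1).
picked = np.take_along_axis(x, idx[..., None], axis=-1)
print(picked.shape)  # (4, 3, 1)

# Drop the trailing axis to get shape (4, 3), like .squeeze(axis=-1).
result = picked[..., 0]

# Compare with the explicit loop from the question.
expected = np.zeros(shape[:-1])
for i in range(shape[0]):
    for j in range(shape[1]):
        expected[i, j] = x[i, j, idx[i, j]]

assert np.array_equal(result, expected)
```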
Alternatively, build index grids for the first two axes and use advanced indexing:
i, j = np.meshgrid(np.arange(shape[0]), np.arange(shape[1]), indexing='ij')
x[i, j, idx]
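The meshgrid approach materializes two full (4, 3) index arrays. A possible variant, sketched below under the assumption that only broadcasting matters, uses np.indices(..., sparse=True), which returns broadcastable grids of shapes (4, 1) and (1, 3) and also works for any number of leading axes. The helper name select_last_axis is hypothetical, and the result is checked against the take_along_axis answer on a 4D example.

```python
import numpy as np

def select_last_axis(x, idx):
    """Hypothetical helper: return x[..., idx[...]] for every leading position.

    idx must have shape x.shape[:-1] and integer entries in [0, x.shape[-1]).
    """
    # Sparse grids broadcast against each other and against idx, so no
    # full-size copies of the leading-axis indices are created.
    grids = np.indices(idx.shape, sparse=True)
    # Advanced indexing with a tuple of broadcastable integer arrays.
    return x[(*grids, idx)]

# Usage on a 4D array: select along the last axis for each (a, b, c).
rng = np.random.default_rng(1)
x = rng.uniform(0, 1, (2, 3, 4, 5))
idx = rng.integers(0, 5, (2, 3, 4))

out = select_last_axis(x, idx)
ref = np.take_along_axis(x, idx[..., None], axis=-1)[..., 0]
assert out.shape == (2, 3, 4)
assert np.array_equal(out, ref)
```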