Tags: python, arrays, numpy, numpy-slicing, numpy-indexing

How do I index a multi-dimensional array using per-row indices along a certain axis?


Let's say I have a 4D array A with shape (D0, D1, D2, D3). I also have a 1D array B with shape (D0,), where B[i] is the index I need along axis 2 for position i on axis 0.

The trivial way to implement what I need:

import numpy as np

output_lis = []
for i in range(D0):
    # Each slice A[i, :, B[i], :] has shape (D1, D3).
    output_lis.append(A[i, :, B[i], :])
# Use np.stack, not np.concatenate: stack adds a new leading axis (thanks to @Mad Physicist).
output = np.stack(output_lis, axis=0)  # shape: (D0, D1, D3)

So my question is: how can I implement this quickly with the NumPy API, without the Python loop?


Solution

  • Use fancy indexing to step along two dimensions in lockstep. In this case, arange provides the sequence i, while B provides the sequence B[i]:

    A[np.arange(D0), :, B, :]
    

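    The shape of this array is (D0, D1, D3). The advanced indices on axes 0 and 2 are separated by a slice, so NumPy places the broadcast index dimension first, followed by the sliced dimensions. The original version of your for loop, which joined the slices with concatenate, produced shape (D0 * D1, D3) instead.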

    To get the same result from your example, use stack (which adds a new axis), rather than concatenate (which uses an existing axis):

    output = np.stack(output_lis, axis=0)
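
To check that the fancy-indexing expression matches the loop, here is a minimal sketch. The sizes and the random data are made up for illustration, and the np.take_along_axis variant is an alternative not mentioned in the answer above, included only for comparison.

```python
import numpy as np

# Hypothetical small sizes, chosen only for illustration.
D0, D1, D2, D3 = 3, 4, 5, 2
rng = np.random.default_rng(0)
A = rng.standard_normal((D0, D1, D2, D3))
B = rng.integers(0, D2, size=D0)  # one axis-2 index per position on axis 0

# Loop version from the question: stack the (D1, D3) slices on a new axis 0.
loop_out = np.stack([A[i, :, B[i], :] for i in range(D0)], axis=0)

# Fancy-indexing version: arange(D0) and B are paired element by element.
fancy_out = A[np.arange(D0), :, B, :]

# Alternative (an assumption, not from the answer): take_along_axis with B
# reshaped so it broadcasts against A on every axis except axis 2.
taken = np.take_along_axis(A, B[:, None, None, None], axis=2)  # (D0, D1, 1, D3)
taken_out = taken[:, :, 0, :]  # drop the length-1 axis

# Joining with concatenate merges the slices along the existing axis 0.
concat_out = np.concatenate([A[i, :, B[i], :] for i in range(D0)], axis=0)

print(loop_out.shape, fancy_out.shape, concat_out.shape)
# (3, 4, 2) (3, 4, 2) (12, 2)
```

All three vectorized and loop results should agree element for element. Only the concatenate variant differs in shape, which is why stack is the right way to reproduce the fancy-indexing result with a loop.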