
Slice array along axis with list of different indices

I have a 3-dimensional array/tensor of shape (a, b, c) and a list B of length a whose entries are indices in the range [0, b), one per slice along the first axis. I want to use these indices to get an array of shape (a, c), where row i is tensor[i, B[i], :]. Right now I do this with an ugly list comprehension:

z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])

This is implemented in a forward pass for a neural network, so I really want to avoid a list comprehension. Is there any torch (or numpy) function that does this more efficiently?

Here is a small example:

tensor = [[[ 0,  1],
           [ 2,  3],
           [ 4,  5]],
          [[ 6,  7],
           [ 8,  9],
           [10, 11]],
          [[12, 13],
           [14, 15],
           [16, 17]],
          [[18, 19],
           [20, 21],
           [22, 23]]]  # shape: (4, 3, 2)
B = [0, 1, 2, 2]
output = [[ 0,  1],
          [ 8,  9],
          [16, 17],
          [22, 23]]  # shape (4, 2)
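The example tensor is just the numbers 0 to 23 in row-major order, so it can be rebuilt compactly. A minimal sketch, assuming PyTorch is installed, that reconstructs it with `torch.arange(24).reshape(4, 3, 2)` and runs the list comprehension from above as a baseline:

```python
import torch

# The example tensor holds 0..23 in row-major order, shape (4, 3, 2).
tensor = torch.arange(24).reshape(4, 3, 2)
B = [0, 1, 2, 2]

# Baseline from the question: pick row B[i] of the i-th (3, 2) slice, one slice at a time.
z = torch.stack([t_[b, :] for t_, b in zip(tensor, B)])
print(z.shape)     # torch.Size([4, 2])
print(z.tolist())  # [[0, 1], [8, 9], [16, 17], [22, 23]]
```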

Background: I have time series data with time windows of different lengths. I use torch's pack_padded_sequence (and its inverse, pad_packed_sequence) to mask the padding, but I need the LSTM output at the last time step before the padding starts, because the outputs after that point are unusable. In the example, I would have 4 sequences whose last valid time steps are at indices 0, 1, 2 and 2, each step having 2 features.
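For this LSTM use case, the last valid output of sequence i in the padded result of `pad_packed_sequence` sits at time index `lengths[i] - 1`, so it can be picked out by indexing each batch row with its own time step. A minimal sketch, assuming a single-layer, unidirectional `nn.LSTM` with `batch_first=True`; the sizes, random input and lengths `[1, 2, 3, 3]` (matching the indices 0, 1, 2, 2 above) are made up for illustration:

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)

# Hypothetical sizes: 4 sequences, up to 3 time steps, 2 input features.
batch, max_len, n_features, hidden = 4, 3, 2, 5
x = torch.randn(batch, max_len, n_features)
lengths = torch.tensor([1, 2, 3, 3])  # last valid indices: 0, 1, 2, 2

lstm = nn.LSTM(n_features, hidden, batch_first=True)

# Pack (sequences need not be sorted by length), run the LSTM, unpack again.
packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, (h_n, c_n) = lstm(packed)
padded, out_lengths = pad_packed_sequence(packed_out, batch_first=True)  # (batch, max len, hidden)

# Row i takes time step out_lengths[i] - 1, i.e. the last step before padding.
last = padded[torch.arange(batch), out_lengths - 1]  # shape (batch, hidden)
```

For a single-layer unidirectional LSTM, `last` should equal `h_n[-1]`, which PyTorch returns in the original batch order when `enforce_sorted=False`, so in that particular setup `h_n` alone may already be enough.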


  • Use advanced indexing. To get the desired output, we need matching indices for the first axis, which are created with torch.arange() below:

    output = tensor[torch.arange(len(B)), B]

    or, using numpy (with tensor as an np.ndarray):

    output = tensor[np.arange(len(B)), B]

    both produce the following (the numpy version returns an array rather than a tensor):

    tensor([[ 0,  1],
            [ 8,  9],
            [16, 17],
            [22, 23]])

    Full code using the example:

    import torch
    tensor = torch.tensor([
        [[ 0,  1],
         [ 2,  3],
         [ 4,  5]],
        [[ 6,  7],
         [ 8,  9],
         [10, 11]],
        [[12, 13],
         [14, 15],
         [16, 17]],
        [[18, 19],
         [20, 21],
         [22, 23]]])
    B = [0, 1, 2, 2]
    # Row i of the result is tensor[i, B[i]]
    output = tensor[torch.arange(len(B)), B]
    print(output)
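An alternative that is not part of the answer above is `torch.gather` along dimension 1; it requires an index with the same number of dimensions as the input, so `B` has to be expanded to shape (a, 1, c). A minimal sketch, assuming `B` is converted to an integer tensor and the example is cast to float so gradients can be tracked, that checks advanced indexing, `gather` and the original list comprehension agree, and that gradients reach only the selected entries:

```python
import torch

# Example tensor from the question, as floats so autograd can track it.
tensor = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2).requires_grad_()
B = torch.tensor([0, 1, 2, 2])
a, b, c = tensor.shape

# Advanced indexing: pairs index i on the first axis with B[i] on the second.
by_index = tensor[torch.arange(a), B]

# gather: reshape B to (a, 1, 1), expand to (a, 1, c), then drop the singleton dim.
idx = B.view(a, 1, 1).expand(a, 1, c)
by_gather = tensor.gather(1, idx).squeeze(1)

# Baseline list comprehension from the question.
by_loop = torch.stack([t_[i, :] for t_, i in zip(tensor, B.tolist())])

# Backprop through the indexed result: only entries (i, B[i], :) get gradient 1.
by_index.sum().backward()
print(tensor.grad[:, :, 0])
```

Advanced indexing is usually the more readable of the two; `gather` may be convenient when the indices already come as a tensor shaped like the output.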