python numpy pytorch advanced-indexing

In PyTorch, how do I slice a tensor across multiple dims with BoolTensor masks?


I want to use BoolTensor masks to slice a multidimensional tensor in PyTorch. I expect the indexed tensor to keep the parts where a mask is True and drop the parts where it is False.

My code is:

import torch
a = torch.zeros((5, 50, 5, 50))

tr_indices = torch.zeros((50), dtype=torch.bool)
tr_indices[1:50:2] = 1
val_indices = ~tr_indices

print(a[:, tr_indices].shape)
print(a[:, tr_indices, :, val_indices].shape)

I expect a[:, tr_indices, :, val_indices] to have shape [5, 25, 5, 25], but it has shape [25, 5, 5]. The output is:

torch.Size([5, 25, 5, 50])
torch.Size([25, 5, 5])

I'm very confused. Can anyone explain why?


Solution

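  • PyTorch inherits its advanced indexing behaviour from NumPy. A boolean mask acts like the integer positions where it is True, and when two such indices appear in the same subscript they are broadcast together and paired element-wise (the 25 True positions of tr_indices are zipped with the 25 True positions of val_indices) instead of being combined as an outer product. Because a slice separates the two advanced indices, the resulting dimension of size 25 is moved to the front, which gives [25, 5, 5]. Applying the masks one at a time avoids the pairing and should give your desired output: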

    a[:, tr_indices][..., val_indices]
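A minimal sketch of the behaviour described above, reusing the question's masks but filling `a` with `torch.arange` instead of zeros so that individual values can be compared. It checks that the combined subscript pairs the two masks element-wise, then verifies the two-step fix. The `index_select` variant is an alternative added here for illustration, not part of the original answer; the names `tr_pos`, `val_pos`, `combined`, `fixed` and `fixed2` are likewise illustrative.

```python
import torch

# Same masks as in the question; distinct values in `a` make comparisons meaningful
a = torch.arange(5 * 50 * 5 * 50).reshape(5, 50, 5, 50)
tr_indices = torch.zeros(50, dtype=torch.bool)
tr_indices[1:50:2] = True          # odd positions
val_indices = ~tr_indices          # even positions

# A boolean mask behaves like the integer positions where it is True
tr_pos = tr_indices.nonzero(as_tuple=True)[0]    # 25 positions
val_pos = val_indices.nonzero(as_tuple=True)[0]  # 25 positions

# Both masks in one subscript: positions are zipped pairwise, and since a slice
# separates them, the paired (size-25) dimension is placed first.
combined = a[:, tr_indices, :, val_indices]
assert combined.shape == (25, 5, 5)
for k in range(25):
    expected = a[:, tr_pos[k].item(), :, val_pos[k].item()]  # shape (5, 5)
    assert torch.equal(combined[k], expected)

# Fix from the answer: apply the masks one after the other
fixed = a[:, tr_indices][..., val_indices]
assert fixed.shape == (5, 25, 5, 25)

# Alternative (assumption, not from the answer): select each dim independently
fixed2 = a.index_select(1, tr_pos).index_select(3, val_pos)
assert torch.equal(fixed, fixed2)
```

Indexing one dimension at a time means no two advanced indices ever share a subscript, so there is nothing to broadcast together and each mask simply shrinks its own dimension.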