Tags: python, pytorch

How can I trim a tensor based on a mask with PyTorch?


I have a tensor inp of size torch.Size([4, 122, 161]).

I also have a mask of size torch.Size([4, 122]).

Each row of the mask looks something like this:

tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
        1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       device='cuda:0', grad_fn=<SelectBackward>)

I want to trim inp along dimension 1 so that only the positions where the mask is 1 remain. In the example shown there are 23 ones, so the trimmed inp should have size torch.Size([4, 23, 161]).
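A quick way to confirm the target size is to count the ones in each mask row. The sketch below rebuilds the printed row (76 zeros, 23 ones, 23 zeros, 122 entries in total) and assumes, for illustration only, that all 4 rows are identical; the feature size 161 is taken from inp's shape in the question.

```python
import torch

# Rebuild the printed mask row: 76 zeros, 23 ones, 23 zeros (122 entries).
row = torch.cat([torch.zeros(76), torch.ones(23), torch.zeros(23)])

# Hypothetical batch: assume all 4 rows look like the printed one.
mask = row.repeat(4, 1)  # shape (4, 122)

# Ones per row; a regular trimmed tensor needs the same count in every row.
counts = mask.sum(dim=1).long()
assert torch.all(counts == counts[0]), "rows keep different numbers of positions"

feat = 161  # feature size from inp's shape in the question
target_shape = torch.Size([mask.size(0), int(counts[0]), feat])
print(target_shape)  # torch.Size([4, 23, 161])
```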


Solution

  • Boolean (advanced) indexing should work, assuming every row of the mask contains exactly 23 ones:

    # mask.bool() picks the kept positions; -1 infers the per-row count (23 here)
    inp_trimmed = inp[mask.bool()].reshape(inp.size(0), -1, inp.size(2))
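The sketch below runs the boolean-indexing approach on synthetic data shaped like the question's (random features, and a hypothetical mask whose rows each hold 23 ones at different offsets) and checks the result against a per-row selection. It also shows what can go wrong when the counts differ: the flattened result is then either not divisible into equal rows (reshape raises) or, worse, divisible by coincidence (reshape silently mixes positions across samples). Splitting per sample and zero-padding with `torch.nn.utils.rnn.pad_sequence` is one possible workaround; it is an assumption on my part, not part of the answer above.

```python
import torch
from torch.nn.utils.rnn import pad_sequence

torch.manual_seed(0)
batch, seq_len, feat, keep = 4, 122, 161, 23

# Hypothetical inputs shaped like the question's: random features and a float
# mask whose rows each hold exactly `keep` ones, at a different offset per row.
inp = torch.randn(batch, seq_len, feat)
mask = torch.zeros(batch, seq_len)
for b in range(batch):
    start = 10 * b + 5
    mask[b, start:start + keep] = 1.0

# Indexing with a (batch, seq_len) boolean mask flattens the first two dims:
# the result is (total_ones, feat), with rows kept in batch order.
selected = inp[mask.bool()]
inp_trimmed = selected.reshape(batch, -1, feat)
print(inp_trimmed.shape)  # torch.Size([4, 23, 161])

# Sanity check: each sample matches a per-row selection.
for b in range(batch):
    assert torch.equal(inp_trimmed[b], inp[b, mask[b].bool()])

# Unequal counts: split the flattened selection per sample, then zero-pad
# every sample to the longest one.
uneven = mask.clone()
uneven[0, :5] = 1.0  # hypothetical: row 0 now has 28 ones
counts = uneven.sum(dim=1).long().tolist()
pieces = torch.split(inp[uneven.bool()], counts)
padded = pad_sequence(list(pieces), batch_first=True)
print(padded.shape)  # torch.Size([4, 28, 161])
```

With equal counts, the padded route returns the same tensor as the reshape, so it can serve as a drop-in when the 23-ones assumption is not guaranteed; the cost is that downstream code must ignore the zero-padded positions.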