I have a tensor inp
, which has a size of: torch.Size([4, 122, 161])
.
I also have a mask
with a size of: torch.Size([4, 122])
.
Each element in my mask
looks something like:
tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
1., 1., 1., 1., 1., 1., 1., 1., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
device='cuda:0', grad_fn=<SelectBackward>)
So I want to trim inp
to be reduced along the dimension=1 to only exist where the mask
has 1
. In the case shown, there are 23 1
s, so I want the size of inp
to be: torch.Size([4, 23, 161])
I think advanced indexing would work. (I assume every mask has equally 23 1s)
inp_trimmed = inp[mask.type(torch.bool)].reshape(4,23,161)