I've recently run into this problem in PyTorch when working with 4D tensors that need to be indexed with 3D tensors.
Let's say we have this 4D tensor:
possible_values.size()
torch.Size([2, 5, 5, 4])
where:
dim 1 = batch
dim 2 = x_axis
dim 3 = y_axis
dim 4 = possible values of coordinate (x_i,y_j)
We then have a 3D "indexing" tensor, which should be used to select the values along dim 4 based on an x and y coordinate:
coordinates.size()
torch.Size([2, 5, 2])
where:
dim 1 = batch
dim 2 = sequence of (x,y) pairs
dim 3 = (x,y) coordinate
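For example, coordinates could look like this (the values here look 1-based, since 5 appears on axes of size 5; PyTorch indexing is 0-based, so 1 would have to be subtracted before using them as indices):
[[[1,5], [3,3], [2,4], [1,3], [2,3]],
 [[1,5], [4,3], [2,1], [5,3], [5,3]]]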
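To make the shapes concrete, here is a minimal sketch that builds tensors matching the question. The values in possible_values are random placeholders, and the coordinates are assumed to be 1-based as written above, so 1 is subtracted to make them valid 0-based indices.

```python
import torch

# Shapes from the question: [batch, x_axis, y_axis, possible values]
B, X, Y, V = 2, 5, 5, 4
# Random placeholder values; any [2, 5, 5, 4] tensor would do here.
possible_values = torch.randn(B, X, Y, V)

# The example coordinates include 5 on axes of size 5, so they are treated
# as 1-based here and shifted to 0-based PyTorch indices.
coordinates = torch.tensor([
    [[1, 5], [3, 3], [2, 4], [1, 3], [2, 3]],
    [[1, 5], [4, 3], [2, 1], [5, 3], [5, 3]],
]) - 1

print(possible_values.size())  # torch.Size([2, 5, 5, 4])
print(coordinates.size())      # torch.Size([2, 5, 2])
```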
What we want to do is select, for each batch element, the possible values at the coordinates specified by coordinates. So from the first batch element we want to select the 4 values at coordinates [1, 5], [3, 3], and so on.
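The desired result can be pinned down with a slow reference loop, which is handy for checking any vectorized version against. This is a sketch; the helper name select_values_loop is hypothetical, and 0-based coordinates are assumed.

```python
import torch

def select_values_loop(possible_values, coordinates):
    """Hypothetical reference helper (slow, Python loops):
    out[b, n] = possible_values[b, x, y] where (x, y) = coordinates[b, n].
    Assumes 0-based coordinates."""
    B, N, _ = coordinates.shape
    V = possible_values.shape[-1]
    out = possible_values.new_empty((B, N, V))
    for b in range(B):
        for n in range(N):
            x, y = coordinates[b, n].tolist()
            out[b, n] = possible_values[b, x, y]
    return out

possible_values = torch.arange(2 * 5 * 5 * 4, dtype=torch.float32).reshape(2, 5, 5, 4)
coordinates = torch.tensor([
    [[0, 4], [2, 2], [1, 3], [0, 2], [1, 2]],
    [[0, 4], [3, 2], [1, 0], [4, 2], [4, 2]],
])
result = select_values_loop(possible_values, coordinates)
print(result.shape)  # torch.Size([2, 5, 4])
```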
I have looked at index_select and gather, but can't get my head around them right now (or make them do roughly what I want).
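For comparison, gather can be made to work here once the (x, y) grid is flattened into a single axis, because gather only follows one index tensor along one dimension. This is a sketch with random placeholder tensors and 0-based coordinates; the row-major linear index x * Y + y is the key step.

```python
import torch

possible_values = torch.randn(2, 5, 5, 4)      # [batch, x, y, values]
coordinates = torch.randint(0, 5, (2, 5, 2))   # [batch, seq, (x, y)], 0-based

B, X, Y, V = possible_values.shape
N = coordinates.shape[1]

# Flatten the (x, y) grid into one axis: [B, X*Y, V]
flat_values = possible_values.reshape(B, X * Y, V)
# Row-major linear index of each (x, y) pair: x * Y + y -> [B, N]
flat_idx = coordinates[..., 0] * Y + coordinates[..., 1]
# gather needs an index with as many dims as its input, so repeat each
# linear index across the value axis: [B, N, V]
gather_idx = flat_idx.unsqueeze(-1).expand(B, N, V)
# out[b, n, v] = flat_values[b, gather_idx[b, n, v], v]
result = torch.gather(flat_values, 1, gather_idx)
print(result.shape)  # torch.Size([2, 5, 4])
```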
Thanks.
OK, let's start by leaving out the batch dimension and handling a single batch element i:
possible_values[i, coordinates[i, :, 0], coordinates[i, :, 1], :]  # output is of shape [5, 4]
The above gives the correct values for a single batch element. Now we need a way to broadcast this operation for all values of i (i.e. across the batch dimension).
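A quick sketch to check the single-element indexing and its shape, using random placeholder tensors and 0-based coordinates. Stacking the per-element results over i is a simple way to cover the whole batch, though it loops in Python rather than broadcasting.

```python
import torch

possible_values = torch.randn(2, 5, 5, 4)      # placeholder values
coordinates = torch.randint(0, 5, (2, 5, 2))   # 0-based (x, y) pairs

i = 0
# The two [5]-shaped coordinate tensors pick 5 (x, y) cells of element i.
single = possible_values[i, coordinates[i, :, 0], coordinates[i, :, 1], :]
print(single.shape)  # torch.Size([5, 4])

# Python-level loop over the batch, stacked into one tensor.
stacked = torch.stack([
    possible_values[i, coordinates[i, :, 0], coordinates[i, :, 1], :]
    for i in range(possible_values.shape[0])
])
print(stacked.shape)  # torch.Size([2, 5, 4])
```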
possible_values[:, coordinates[:, :, 0], coordinates[:, :, 1], :]  # output is of shape [2, 2, 5, 4]
This is mostly correct, but it is "over-broadcasted": it applies the coordinates of each batch element to EVERY batch element's values. We now need to keep only the main-diagonal elements across the first 2 dimensions, so that each batch element is indexed only with its own coordinates:
batch_size = possible_values.shape[0]
batch_idx = torch.arange(batch_size)
possible_values[:, coordinates[:, :, 0], coordinates[:, :, 1], :][batch_idx, batch_idx, :, :]  # output is of shape [2, 5, 4]
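Here is a runnable sketch of the diagonal approach, with random placeholder tensors and 0-based coordinates. It also shows a variant worth comparing: indexing the batch axis directly with batch_idx[:, None] (shape [B, 1]), which broadcasts against the [B, N] coordinate tensors and avoids building the [B, B, N, V] intermediate. Both produce the same values.

```python
import torch

possible_values = torch.randn(2, 5, 5, 4)      # placeholder values
coordinates = torch.randint(0, 5, (2, 5, 2))   # 0-based (x, y) pairs

# Over-broadcasted result: over[j, i] pairs batch j's values with
# batch i's coordinates, giving shape [2, 2, 5, 4].
over = possible_values[:, coordinates[:, :, 0], coordinates[:, :, 1], :]

batch_size = possible_values.shape[0]
batch_idx = torch.arange(batch_size)

# Diagonal approach: keep only the entries where j == i -> [2, 5, 4]
diag = over[batch_idx, batch_idx, :, :]

# Direct approach: batch_idx[:, None] is [B, 1] and broadcasts with the
# [B, N] coordinate tensors, so no [B, B, N, V] tensor is materialized.
direct = possible_values[batch_idx[:, None], coordinates[:, :, 0], coordinates[:, :, 1]]

print(over.shape)                 # torch.Size([2, 2, 5, 4])
print(diag.shape)                 # torch.Size([2, 5, 4])
print(torch.equal(diag, direct))  # True
```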
This solution leaves something to be desired in that it doesn't extend to arbitrarily many dimensions without modification (e.g. if you added a z-axis, you'd have to add an additional coordinates[:,:,2] index to the block, and so on).
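One way around that limitation, sketched here under the assumption that the coordinate axes sit between the batch axis and the value axis, is to split the last dimension of coordinates with unbind and unpack the pieces into a tuple index. The helper name select_by_coordinates is hypothetical, and coordinates are assumed to be 0-based.

```python
import torch

def select_by_coordinates(values, coords):
    """Hypothetical helper: values is [B, D1, ..., Dk, V], coords is [B, N, k]
    (0-based). Returns [B, N, V] with out[b, n] = values[b, *coords[b, n]]."""
    batch_idx = torch.arange(values.shape[0], device=values.device)[:, None]  # [B, 1]
    # unbind(-1) splits coords into k tensors of shape [B, N], one per axis.
    return values[(batch_idx, *coords.unbind(-1))]

# 3D case (x, y, z): values [2, 5, 5, 5, 4], coords [2, 7, 3]
values = torch.randn(2, 5, 5, 5, 4)
coords = torch.randint(0, 5, (2, 7, 3))
out = select_by_coordinates(values, coords)
print(out.shape)  # torch.Size([2, 7, 4])
```

The same helper handles the original 2D case unchanged, since k is taken from the last dimension of coords.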