python pytorch spatial-data

Selecting values from a 4d tensor with a 3d tensor


I've recently run into this problem in PyTorch when working with 4D tensors that need to be indexed with 3D tensors.

Let's say we have this 4D tensor:

possible_values.size()
torch.Size([2, 5, 5, 4])

where:

dim 1 = batch
dim 2 = x_axis
dim 3 = y_axis
dim 4 = possible values of coordinate (x_i,y_j)

we then have a 3D "indexing" tensor, which should be used to select the values of dim 4, based on an x and y coordinate:

coordinates.size()
torch.Size([2, 5, 2])

where:

dim 1 = batch
dim 2 = sequences of (x,y) 
dim 3 = (x,y) coordinate

for example, coordinates would look like

[ [ [1,5] [3,3] [2,4] [1,3] [2,3] ]
  [ [1,5] [4,3] [2,1] [5,3] [5,3] ] ]

what we want to do is to select from a batch the possible values for the coordinates specified by coordinates. So from the first batch we want to select the 4 values at coordinates [1, 5], [3, 3] and so on.
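To make the intent concrete, here is a naive loop version of the selection I am after. It is only a sketch: the random tensors and the helper name `select_loop` are made up for illustration, and it assumes the coordinates are zero-based (so every value lies in 0..4), unlike the one-based-looking example above.

```python
import torch

torch.manual_seed(0)
possible_values = torch.randn(2, 5, 5, 4)        # [batch, x, y, values]
# Hypothetical zero-based coordinates, one (x, y) pair per sequence step
coordinates = torch.randint(0, 5, (2, 5, 2))     # [batch, seq, (x, y)]

def select_loop(possible_values, coordinates):
    """Reference implementation: pick the value vector at each (x, y) per batch element."""
    batch, seq = coordinates.shape[:2]
    out = possible_values.new_empty(batch, seq, possible_values.shape[-1])
    for b in range(batch):
        for s in range(seq):
            x, y = coordinates[b, s].tolist()
            out[b, s] = possible_values[b, x, y]
    return out

expected = select_loop(possible_values, coordinates)
print(expected.shape)  # torch.Size([2, 5, 4])
```

What I am looking for is a vectorized equivalent of this loop.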

I have looked at index_select and gather, but can't get my head around them at the moment (or make them do roughly what I want).

Thanks.


Solution

  • OK, let's start by setting the batch dimension aside and looking at a single batch element i:

    possible_values[i,coordinates[i,:,0],coordinates[i,:,1],:]  # output is of shape [5,4]
    

    The above gives the correct values for a single batch element. Now we need a way to broadcast this operation for all values of i (i.e. across the batch dimension).

    possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:]  # output is of shape [2,2,5,4]
    

    This is mostly correct, but it is "over-broadcasted" (i.e. it returns the desired indices for each batch element, for EVERY batch element). We now need to index just the main diagonal elements across the first 2 dimensions, so that we get the desired indices for each batch element, for only its OWN batch element:

    batch_size = possible_values.shape[0]
    batch_idx = torch.arange(batch_size)
    possible_values[:,coordinates[:,:,0],coordinates[:,:,1],:][batch_idx,batch_idx,:,:]   # output is of shape [2,5,4]
    

    This solution leaves something to be desired in that it doesn't extend to arbitrarily many dimensions without modification (i.e. if you added a z-axis, you'd have to add an additional coordinates[:,:,2] index to the block, and so on).
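As a possible variation on the steps above (not the exact method described, but the same advanced-indexing idea), the batch index can be passed as a broadcastable index tensor alongside the coordinate indices. This sketch avoids building the intermediate [2,2,5,4] tensor, and splitting the last dimension with `unbind` lets it handle any number of spatial axes. The helper name `select_by_coordinates` and the random test tensors are hypothetical, and zero-based integer coordinates are assumed.

```python
import torch

def select_by_coordinates(possible_values, coordinates):
    """Gather value vectors per batch element.

    possible_values: [batch, d1, ..., dk, values]
    coordinates:     [batch, seq, k], zero-based integer (long) indices
    returns:         [batch, seq, values]
    """
    batch_size = possible_values.shape[0]
    # Shape [batch, 1] so it broadcasts against the [batch, seq] coordinate indices
    batch_idx = torch.arange(batch_size, device=coordinates.device).unsqueeze(1)
    # unbind splits the last dim into k index tensors, each of shape [batch, seq]
    index = (batch_idx, *coordinates.unbind(dim=-1))
    return possible_values[index]

# 2D case from the question: compare against the "diagonal" approach
possible_values = torch.randn(2, 5, 5, 4)
coordinates = torch.randint(0, 5, (2, 5, 2))
out = select_by_coordinates(possible_values, coordinates)

batch_idx = torch.arange(possible_values.shape[0])
diag = possible_values[:, coordinates[:, :, 0], coordinates[:, :, 1], :][batch_idx, batch_idx]

# Same helper with an extra z-axis of size 3, no code changes needed
values_3d = torch.randn(2, 5, 5, 3, 4)
coords_3d = torch.cat(
    [torch.randint(0, 5, (2, 5, 2)), torch.randint(0, 3, (2, 5, 1))], dim=-1
)
out_3d = select_by_coordinates(values_3d, coords_3d)
```

Here `out` and `diag` should hold the same values; the difference is that the broadcasted batch index pairs each batch element with its own coordinates directly, instead of computing every pairing and then keeping only the diagonal.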