I have a distance tensor:
tensor([ 5, 10, 2, 3, 4], device='cuda:0')
And an indices tensor:
tensor([ 0, 2, 3], device='cuda:0')
I want to find argmax of the distance tensor but only on the subset of indices specified by the indices tensor.
In this example, I would look at the 0th, 2nd and 3rd elements of the distance tensor (values 5, 2 and 3) and return index 0, because the largest of these values, 5, is at position 0 in the distance tensor:
tensor([ 0], device='cuda:0')
Is something like this feasible without for loops? Thanks.
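One loop-free way to express exactly this computation is to mask out every position that is not in the indices tensor and then take a single argmax over the whole tensor. This is a minimal sketch, not taken from the thread; it runs on CPU tensors for illustration (the same calls work on CUDA tensors), and the fill value is an assumption chosen so that unselected entries can never win.

```python
import torch

# Values from the question (CPU here; the same code works on a CUDA tensor).
dist = torch.tensor([5, 10, 2, 3, 4])
indices = torch.tensor([0, 2, 3])

# Boolean mask that is True only at the positions listed in `indices`.
mask = torch.zeros_like(dist, dtype=torch.bool)
mask[indices] = True

# Replace every unselected entry with the smallest value the dtype can hold,
# so argmax over the whole tensor can only land on a selected position.
if dist.is_floating_point():
    fill = float("-inf")
else:
    fill = torch.iinfo(dist.dtype).min
masked = dist.masked_fill(~mask, fill)

# argmax returns a 0-dim tensor holding an index into the original `dist`.
result = torch.argmax(masked)
print(result)  # tensor(0)
```

If the one-element shape from the question is needed, result.unsqueeze(0) turns tensor(0) into tensor([0]).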
Here is an example. You can check that the maximum dist value among the selected items is at index zero, and the final output tensor contains zero too. Note that dim=0 in torch.index_select selects along the first dimension, which is the only dimension of a 1D tensor.
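To make the role of the dim argument concrete, here is a small sketch using fixed values instead of random ones: for a 1D tensor, torch.index_select(dist, 0, idx) gives the same result as plain indexing dist[idx], while for a column vector of shape (5, 1) the same call selects whole rows.

```python
import torch

dist = torch.tensor([5.0, 10.0, 2.0, 3.0, 4.0])
idx = torch.tensor([0, 2, 3])

# For a 1D tensor, dim=0 is the only dimension, so these two are equivalent.
a = torch.index_select(dist, 0, idx)
b = dist[idx]
print(a)                  # tensor([5., 2., 3.])
print(torch.equal(a, b))  # True

# For a column vector of shape (5, 1), dim=0 selects whole rows instead.
col = dist.unsqueeze(1)
print(torch.index_select(col, 0, idx).shape)  # torch.Size([3, 1])
```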
import torch
dist = torch.randn(5)  # random values; the output below is one sample run
#tensor([ 0.3392,  0.4472,  0.1398, -1.0379,  0.2950])
idx = torch.tensor([0,2,3])
#tensor([0, 2, 3])
Then take the maximum of the selected values with torch.max and look up where that value occurs in dist:
max_val = torch.max(torch.index_select(dist, 0, idx)).item()
#0.33918169140815735
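# Caution: this compares against all of dist, so an equal value at a position outside idx would also be returned.
(dist == max_val).nonzero(as_tuple=True)[0]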
#tensor([0])
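An alternative sketch, not from the original answer, avoids the equality search entirely: take argmax over the selected values only, then map that position back through the indices tensor. This always returns exactly one index, and it is guaranteed to be one of the requested positions. The values below are the ones from the question, on CPU for illustration.

```python
import torch

dist = torch.tensor([5, 10, 2, 3, 4])
indices = torch.tensor([0, 2, 3])

# Gather the candidate values, find the position of the largest one
# within the subset, then translate that position back through `indices`.
subset = dist[indices]                     # tensor([5, 2, 3])
local_pos = torch.argmax(subset)           # position inside `subset`
result = indices[local_pos].unsqueeze(0)   # index into the original `dist`
print(result)  # tensor([0])
```

Because the lookup goes through indices, an equal or larger value elsewhere in dist (such as the 10 at position 1 here) cannot leak into the result.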