I have a 2D tensor and I want to get the indices of the top k values. I know about PyTorch's topk function, but it only computes the top k values along a single dimension. I want the top k values over both dimensions, that is, over the whole tensor.
For example, for the following tensor:
a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])
PyTorch's topk function gives me the following:
values, indices = torch.topk(a, 3)
print(indices)
# tensor([[1, 2, 0],
#         [0, 2, 1],
#         [0, 1, 4],
#         [1, 4, 3],
#         [1, 0, 4]])
But I want to get the following:
tensor([[0, 1],
        [2, 0],
        [3, 1]])
These are the indices of the three 9s in the 2D tensor.
Is there a way to achieve this using PyTorch?
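import numpy as np
import torch
v, i = torch.topk(a.flatten(), 3)  # top 3 values and their flat (1D) positions
print(np.array(np.unravel_index(i.numpy(), a.shape)).T)  # flat positions -> [row, col] pairs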
Output:
[[3 1]
 [2 0]
 [0 1]]
See numpy.unravel_index, which converts flat indices into per-dimension coordinates for a given shape.
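If you would rather stay in PyTorch and skip the NumPy round trip, the same idea can be sketched with integer division and modulo on the flat indices. This is only a minimal sketch: the helper name topk_2d_indices is made up for illustration, it assumes a 2D input, and it uses torch.div with rounding_mode="floor", which needs a reasonably recent PyTorch release.

```python
import torch


def topk_2d_indices(t, k):
    """Return a (k, 2) tensor of [row, col] indices of the k largest
    entries of a 2D tensor. Hypothetical helper, not a PyTorch API."""
    n_cols = t.shape[1]
    # topk over the flattened tensor yields positions in row-major order
    _, flat_idx = torch.topk(t.flatten(), k)
    # Recover the row and column from each flat position
    rows = torch.div(flat_idx, n_cols, rounding_mode="floor")
    cols = flat_idx % n_cols
    return torch.stack((rows, cols), dim=1)


a = torch.tensor([[4, 9, 7, 4, 0],
                  [8, 1, 3, 1, 0],
                  [9, 8, 4, 4, 8],
                  [0, 9, 4, 7, 8],
                  [8, 8, 0, 1, 4]])

print(topk_2d_indices(a, 3))
```

For the tensor above this returns the three positions holding 9. The order of rows among equal values is not something to rely on, so sort the result if you need a stable order.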