I know that torch.argmax(x, dim=0) returns the index of the first maximum value in x along dimension 0. But is there an efficient way to return the indices of the first n maximum values? If there are duplicate values, I want each of them to contribute its own index to the n indices.
As a concrete example, say x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1]). I would like a function generalized_argmax(x: torch.Tensor, n: int) such that generalized_argmax(x, 4) returns [0, 2, 4, 5] in this example.
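To pin down what "first n maximum values" means when duplicates are involved, here is a plain-Python reference for one reading of the request. It assumes ties are broken in favour of the smaller index (so with n = 3 the 2 at index 0 is chosen over the 2 at index 5). The name generalized_argmax_reference is hypothetical, and this is a specification sketch rather than an efficient implementation.

```python
def generalized_argmax_reference(values, n):
    """Hypothetical reference: indices of the n largest values, ascending.

    Assumption: among equal values, the smaller index is preferred.
    """
    # Rank positions by value (largest first), breaking ties by smaller index.
    order = sorted(range(len(values)), key=lambda i: (-values[i], i))
    # Keep the n best positions and report them in ascending index order.
    return sorted(order[:n])


print(generalized_argmax_reference([2, 1, 4, 1, 4, 2, 1, 1], 4))  # [0, 2, 4, 5]
```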
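To find them you have to look at every element of the tensor anyway, so a simple approach is to sort the indices with argsort in descending order and keep the first n entries. Pass stable=True so that equal values keep their original order; otherwise the order among ties is not guaranteed, and when duplicates straddle the cutoff it is unspecified which of them end up in the first n.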
>>> import torch
>>> x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
>>> n = 4
>>> x.argsort(dim=0, descending=True, stable=True)[:n]
tensor([2, 4, 0, 5])
Sort the result again (for example with .sort().values) to get [0, 2, 4, 5] if you need the indices in ascending order.
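Putting both steps together, a minimal sketch of the generalized_argmax function from the question could look like this. It assumes a 1-D tensor, uses torch.sort with stable=True (equivalent to the stable argsort above, since torch.sort also returns the indices), and then sorts the kept indices so they come out in ascending order.

```python
import torch


def generalized_argmax(x: torch.Tensor, n: int) -> torch.Tensor:
    """Indices of the n largest entries of a 1-D tensor, in ascending order.

    A stable sort keeps equal values in their original order, so among
    duplicates the smaller index is preferred.
    """
    # Stable descending sort; .indices gives the positions in sorted order.
    order = torch.sort(x, dim=0, descending=True, stable=True).indices
    # Keep the n best positions, then sort them into ascending index order.
    return order[:n].sort().values


x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
print(generalized_argmax(x, 4))  # tensor([0, 2, 4, 5])
print(generalized_argmax(x, 3))  # tensor([0, 2, 4])
```

With n = 3 the two 2s tie for the last slot, and the stable sort picks index 0 over index 5.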