
How to find the indexes of the first $n$ maximum values of a tensor?


I know that torch.argmax(x, dim=0) returns the index of the first maximum value in x along dimension 0. But is there an efficient way to return the indexes of the first n maximum values? If there are duplicate values, I also want their indexes to be included among the n indexes.
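For reference, here is a quick check of the torch.argmax behaviour described above. It shows that argmax returns a single index and, on ties, the index of the first maximal value (as stated in the PyTorch documentation). The tensor x is just an illustrative input.

```python
import torch

x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])

# argmax gives exactly one index; with two 4s (at 2 and 4) it returns the first one
idx = torch.argmax(x, dim=0)
print(idx.item())  # 2
```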

As a concrete example, say x=torch.tensor([2, 1, 4, 1, 4, 2, 1, 1]). I would like a function

generalized_argmax(x: torch.Tensor, n: int)

such that generalized_argmax(x, 4) returns [0, 2, 4, 5] in this example.


Solution

  • To get all the indices you need, you have to go over the whole tensor anyway. The most efficient approach should therefore be to use argsort and keep only the first n entries of its output.

    >>> import torch
    >>> x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
    >>> n = 4
    >>> x.argsort(dim=0, descending=True)[:n]
    tensor([2, 4, 0, 5])
    

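    Sort the result again to get [0, 2, 4, 5] if you need the indices in ascending order. Note that argsort is not stable by default, so if n falls inside a group of equal values, which of the tied indices get selected is not guaranteed.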
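The approach above can be wrapped into the generalized_argmax function from the question. This is a minimal sketch under two assumptions: the input is a 1-D tensor, and "first n" means that among duplicates the earliest indices should win. To make that tie-breaking deterministic it uses torch.sort with stable=True (available in recent PyTorch versions) instead of plain argsort, then sorts the selected indices into ascending order.

```python
import torch


def generalized_argmax(x: torch.Tensor, n: int) -> torch.Tensor:
    """Indices of the n largest values of a 1-D tensor, in ascending index order.

    Assumption: ties are broken by position (earlier indices win).
    """
    if x.dim() != 1:
        raise ValueError("this sketch only handles 1-D tensors")
    # Stable descending sort: equal values keep their original relative order,
    # so among duplicates the earliest indices come first.
    _, order = torch.sort(x, dim=0, descending=True, stable=True)
    # Keep the first n positions of the sorted order.
    top = order[:n]
    # Return the selected indices in ascending order.
    return torch.sort(top).values


x = torch.tensor([2, 1, 4, 1, 4, 2, 1, 1])
print(generalized_argmax(x, 4).tolist())  # [0, 2, 4, 5]
```

With n = 3 the cut falls between the two 2s, and the stable sort picks index 0 over index 5, giving [0, 2, 4].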