
How to sort a one hot tensor according to a tensor of indices


Given the following tensor:

tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])

and the following tensor of indices:

indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])  

How can I reorder the rows of tensor according to the values in indices?

Trying Python's built-in sorted gives the error:

TypeError: 'Tensor' object is not callable

while numpy.sort gives:

ValueError: Cannot specify order when the array has no fields.


Solution

  • No sorting function is needed. Index the tensor directly with indices (advanced indexing along the first dimension):

    import torch

    tensor = torch.Tensor([[1., 0., 0., 0., 0.],
                           [0., 1., 0., 0., 0.],
                           [0., 0., 1., 0., 0.],
                           [0., 0., 0., 0., 1.],
                           [1., 0., 0., 0., 0.],
                           [1., 0., 0., 0., 0.],
                           [0., 0., 0., 1., 0.],
                           [0., 0., 0., 0., 1.]])
    indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])
    # Row i of the result is tensor[indices[i]]
    sorted_tensor = tensor[indices]
    print(sorted_tensor)
    # output
    tensor([[0., 0., 1., 0., 0.],
            [0., 0., 0., 1., 0.],
            [0., 0., 0., 0., 1.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [1., 0., 0., 0., 0.],
            [0., 0., 0., 0., 1.],
            [0., 1., 0., 0., 0.]])
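"Sort by indices" can be read two ways. tensor[indices] is a gather: row indices[i] lands at position i. If indices[i] is instead meant as the destination position of row i, you need the inverse permutation, which torch.argsort gives when indices is a permutation of 0..n-1 (true for this example). The sketch below assumes that, and the names gathered, inverse and scattered are only illustrative. It also checks that tensor[indices] matches torch.index_select along dim 0.

```python
import torch

# One-hot rows from the question (8 samples, 5 classes).
tensor = torch.tensor([[1., 0., 0., 0., 0.],
                       [0., 1., 0., 0., 0.],
                       [0., 0., 1., 0., 0.],
                       [0., 0., 0., 0., 1.],
                       [1., 0., 0., 0., 0.],
                       [1., 0., 0., 0., 0.],
                       [0., 0., 0., 1., 0.],
                       [0., 0., 0., 0., 1.]])
indices = torch.tensor([2, 6, 7, 5, 4, 0, 3, 1])

# Reading 1 (gather): row i of the result is tensor[indices[i]].
gathered = tensor[indices]
# index_select along dim 0 gives the same result.
assert torch.equal(gathered, torch.index_select(tensor, 0, indices))

# Reading 2 (scatter): row i of tensor should end up at position indices[i].
# For a permutation, argsort returns its inverse.
inverse = torch.argsort(indices)
scattered = tensor[inverse]

# Equivalent explicit scatter: write each row to its target position.
scattered_explicit = torch.empty_like(tensor)
scattered_explicit[indices] = tensor
assert torch.equal(scattered, scattered_explicit)

# Class index of each row after reordering.
print(gathered.argmax(dim=1))   # tensor([2, 3, 4, 0, 0, 0, 4, 1])
print(scattered.argmax(dim=1))  # tensor([0, 4, 0, 3, 0, 4, 1, 2])
```

If indices contains repeats or gaps, argsort no longer yields an inverse permutation. The plain gather tensor[indices] still works in that case, but it may duplicate or drop rows.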