python indexing pytorch reduction

How to get the unique elements of a PyTorch tensor and the indices of their first occurrences?


Assume a 2×X PyTorch tensor (always 2 rows):

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

torch.unique(A, dim=1) will return:

tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
        [43., 33., 43., 33., 76., 55.]])

But I also need, for every unique column, the index at which it first appears in the original input. In this case, the indices should be:

tensor([0, 1, 2, 3, 4, 6])

# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6) 

This is tricky for me because the second row of tensor A is not necessarily sorted:

A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
                             ^         ^

Is there a simple and efficient method to get the desired indices?

P.S. It may help that the first row of the tensor is always in ascending order.


Solution

  • One possible way to obtain these indices:

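    import torch

    # Unique columns, the unique-column id of every input column, and the group sizes
    unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
    # A stable sort groups columns by id and keeps the original order inside each group
    _, ind_sorted = torch.sort(idx, stable=True)
    # Exclusive cumulative sum of the counts: the start position of each group
    cum_sum = counts.cumsum(0)
    cum_sum = torch.cat((torch.zeros(1, dtype=cum_sum.dtype, device=cum_sum.device), cum_sum[:-1]))
    first_indices = ind_sorted[cum_sum]
    

    For tensor A in the snippet above:

    print(first_indices)
    # tensor([0, 1, 2, 4, 3, 6])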
    

    Note that the indices follow the column order of unique, which is why 4 comes before 3. In this case unique is equal to:

    tensor([[ 1.,  2.,  2.,  3.,  3.,  4.],
            [43., 33., 43., 33., 76., 55.]])
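
The steps in this answer can be wrapped into a small helper. The following is only a sketch under a few assumptions: the function name `first_occurrence_indices` and the variable names are made up for illustration, and `torch.sort(..., stable=True)` needs a PyTorch version that supports the `stable` argument. It also shows how to sort the result if the indices are wanted in order of appearance, as in the question, rather than aligned with `unique`.

```python
import torch


def first_occurrence_indices(t: torch.Tensor):
    """Return the unique columns of `t` and the index where each first appears.

    Hypothetical helper name; it wraps the approach from the answer above.
    """
    unique, inverse, counts = torch.unique(
        t, dim=1, sorted=True, return_inverse=True, return_counts=True
    )
    # A stable sort of the inverse mapping groups columns by unique id
    # while keeping the original column order inside each group.
    _, order = torch.sort(inverse, stable=True)
    # Start offset of each group: exclusive cumulative sum of the counts.
    starts = counts.cumsum(0) - counts
    return unique, order[starts]


A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

unique, first_idx = first_occurrence_indices(A)
print(first_idx)            # tensor([0, 1, 2, 4, 3, 6]), aligned with `unique`

# Sorting the indices lists them in order of appearance, as in the question.
ordered_idx, _ = torch.sort(first_idx)
print(ordered_idx)          # tensor([0, 1, 2, 3, 4, 6])
print(A[:, ordered_idx])    # unique columns in their original order
```

Computing the offsets as `counts.cumsum(0) - counts` avoids creating a separate zero tensor, so the computation stays on whatever device the input lives on.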