Assume a PyTorch tensor of shape 2 x X (always 2 rows):
A = tensor([[ 1., 2., 2., 3., 3., 3., 4., 4., 4.],
[43., 33., 43., 76., 33., 76., 55., 55., 55.]])
torch.unique(A, dim=1)
will return:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
But I also need, for each unique column, the index at which it first appears in the original input. In this case, the indices should be:
tensor([0, 1, 2, 3, 4, 6])
# Explanation
# A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
#             [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
#              (0)  (1)  (2)  (3)  (4)       (6)
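To pin down exactly what is being asked, a plain-Python reference can help. This is a minimal sketch (the helper name first_occurrence_loop is hypothetical, not a PyTorch function) that walks the columns once and records the first index at which each distinct column appears. It is too slow for large tensors, but it is handy for checking a vectorized version.

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

def first_occurrence_loop(t):
    """Hypothetical reference helper: index of the first occurrence of each
    distinct column of t, listed in order of first appearance."""
    seen = {}  # dicts keep insertion order (Python 3.7+)
    for j, col in enumerate(t.T.tolist()):
        key = tuple(col)
        if key not in seen:
            seen[key] = j
    return torch.tensor(list(seen.values()))

print(first_occurrence_loop(A))  # tensor([0, 1, 2, 3, 4, 6])
```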
This is tricky for me because the second row of A is not necessarily sorted within a group of equal first-row values:
A = tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
            [43., 33., 43., 76., 33., 76., 55., 55., 55.]])
                            ^    ^
Is there a simple and efficient method to get the desired indices?
P.S. It may be useful that the first row of the tensor is always in ascending order.
One possible way to obtain these indices:
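import torch
# idx[j]: position in `unique` of column j of A; counts[k]: occurrences of unique column k
unique, idx, counts = torch.unique(A, dim=1, sorted=True, return_inverse=True, return_counts=True)
# a stable sort groups column positions by unique column, preserving original order within each group
_, ind_sorted = torch.sort(idx, stable=True)
# exclusive prefix sum of counts = start of each group in ind_sorted (new_zeros keeps dtype/device)
cum_sum = counts.cumsum(0)
cum_sum = torch.cat((cum_sum.new_zeros(1), cum_sum[:-1]))
first_indicies = ind_sorted[cum_sum]  # first element of each group = first occurrence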
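To see why this works, it helps to print the intermediates for the A from the question. This sketch repeats the same steps (starts is just an illustrative name for the exclusive prefix sum, and counts.new_zeros(1) is used so the leading zero shares the dtype and device of counts). It assumes a PyTorch version whose torch.sort accepts stable=True (1.9 or later).

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

unique, idx, counts = torch.unique(A, dim=1, return_inverse=True, return_counts=True)
# idx[j]: which unique column column j of A maps to
print(idx)         # tensor([0, 1, 2, 4, 3, 4, 5, 5, 5])
# counts[k]: how many times unique column k occurs in A
print(counts)      # tensor([1, 1, 1, 1, 2, 3])

# A stable sort groups column positions by unique column and keeps each
# group in original order, so every group starts with its first occurrence.
_, ind_sorted = torch.sort(idx, stable=True)
print(ind_sorted)  # tensor([0, 1, 2, 4, 3, 5, 6, 7, 8])

# Exclusive prefix sum of counts = start offset of each group in ind_sorted.
starts = torch.cat((counts.new_zeros(1), counts.cumsum(0)[:-1]))
print(starts)      # tensor([0, 1, 2, 3, 4, 6])

first_idx = ind_sorted[starts]
print(first_idx)   # tensor([0, 1, 2, 4, 3, 6])
```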
For the tensor A above:
print(first_indicies)
# tensor([0, 1, 2, 4, 3, 6])
Note that the indices follow the column order of unique (hence 4 before 3), which in this case is:
tensor([[ 1., 2., 2., 3., 3., 4.],
[43., 33., 43., 33., 76., 55.]])
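If the indices are wanted in order of first appearance, as in the question's expected tensor([0, 1, 2, 3, 4, 6]), sorting them is enough. The sketch below also swaps in a different technique for the first-index step: Tensor.scatter_reduce with reduce="amin" (available in PyTorch 1.12 and later) takes, for each unique column, the smallest column position that maps to it, which avoids the stable sort. Names such as first and order are illustrative only.

```python
import torch

A = torch.tensor([[ 1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.],
                  [43., 33., 43., 76., 33., 76., 55., 55., 55.]])

unique, idx = torch.unique(A, dim=1, return_inverse=True)
n = A.size(1)
positions = torch.arange(n, device=A.device)

# For each unique column, take the minimum column position mapping to it.
# Starting from n (larger than any position) means include_self=True
# cannot change the result.
first = torch.full((unique.size(1),), n, dtype=torch.long, device=A.device)
first = first.scatter_reduce(0, idx, positions, reduce="amin", include_self=True)
print(first)            # tensor([0, 1, 2, 4, 3, 6])  (order of `unique`)

# Reorder to the order of first appearance used in the question,
# permuting the unique columns the same way.
order = torch.argsort(first)
print(first[order])     # tensor([0, 1, 2, 3, 4, 6])
print(unique[:, order]) # unique columns in order of first appearance
```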