Given a NumPy ndarray like the following:
x = [[4.,0.,2.,0.,8.],
[1.,3.,0.,9.,5.],
[0.,0.,4.,0.,1.]]
I want to find the indices of the top k (e.g. k=3) elements of each row, excluding zeros, if possible. If a row has fewer than k positive elements, then just return the indices of those elements (in sorted order).
The result should look like this (a list of arrays):
res = [[4, 0, 2],
[3, 4, 1],
[2, 4]]
or just one flattened array:
res = [4,0,2,3,4,1,2,4]
I know argsort can find the indices of the top k elements in sorted order, but I am not sure how to filter out the zeros.
You can use numpy.argsort on the negated array (-x) to get the indices in descending order of value, then use numpy.take_along_axis to gather the values of the 2D array in that sorted order. Because you want to ignore zeros, you can set every column after the third to zero (k=3, as you mention in the question). Finally, return the indices whose sorted values are not zero.
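import numpy as np

x = np.array([[4.,0.,2.,0.,8.],[1.,3.,0.,9.,5.],[0.,0.,4.,0.,1.]])
# Column indices of each row, ordered by descending value
idx_srt = np.argsort(-x)
val_srt = np.take_along_axis(x, idx_srt, axis=-1)
# Zero out everything past the top k=3 columns
val_srt[:, 3:] = 0
# Keep only the indices whose sorted value is non-zero
res = idx_srt[val_srt!=0]
print(res)
[4 0 2 3 4 1 2 4]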
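If you would rather have the per-row form from the question (a list of arrays) and a configurable k, the same idea can be wrapped up roughly as below. This is only a sketch: `top_k_nonzero_indices` is a hypothetical helper name, and it assumes the array is 2D with non-negative values, since a negative entry would sort after the zeros and could fall outside the first k columns.

```python
import numpy as np

def top_k_nonzero_indices(x, k):
    """Per row, column indices of the k largest non-zero entries.

    Hypothetical helper; assumes x is 2D and non-negative.
    """
    # Indices ordered by descending value; a stable sort keeps ties deterministic
    idx_srt = np.argsort(-x, axis=-1, kind="stable")
    val_srt = np.take_along_axis(x, idx_srt, axis=-1)
    # Keep a position only if it is within the first k columns and non-zero
    keep = val_srt != 0
    keep[:, k:] = False
    # Boolean indexing flattens row by row, so split at the per-row counts
    counts = keep.sum(axis=1)
    return np.split(idx_srt[keep], np.cumsum(counts)[:-1])

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
rows = top_k_nonzero_indices(x, 3)
print([r.tolist() for r in rows])
# [[4, 0, 2], [3, 4, 1], [2, 4]]
```

Masking instead of overwriting `val_srt` leaves the sorted values intact, and `np.concatenate(rows)` gives back the flattened result if you need it.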