Find the top-k elements of a numpy ndarray, ignoring zeros


Given a numpy ndarray like the following:

x = [[4.,0.,2.,0.,8.],
     [1.,3.,0.,9.,5.],
     [0.,0.,4.,0.,1.]]

I want to find the indices of the top k (e.g. k=3) elements of each row, excluding 0. If a row has fewer than k positive elements, just return the indices of those elements (sorted by value, largest first).

The result should look like this (a list of arrays):

res = [[4, 0, 2],
       [3, 4, 1],
       [2, 4]]

or like this (a single flattened array):

res = [4,0,2,3,4,1,2,4]
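To pin down exactly what is being asked, here is a plain-loop reference sketch (not vectorized, so only useful for checking a faster version on small inputs). It assumes the entries are non-negative, as in the example, so "nonzero" and "positive" mean the same thing; `top_k_nonzero_reference` is a hypothetical helper name.

```python
import numpy as np

def top_k_nonzero_reference(x, k):
    """Per row: indices of the k largest nonzero values, largest first.

    Assumes non-negative input, so nonzero entries are the positive ones.
    """
    result = []
    for row in x:
        # Candidate column indices whose value is nonzero
        nonzero = [i for i, v in enumerate(row) if v != 0]
        # Order candidates by their value, largest first, and keep at most k
        nonzero.sort(key=lambda i: row[i], reverse=True)
        result.append(nonzero[:k])
    return result

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
print(top_k_nonzero_reference(x, 3))  # [[4, 0, 2], [3, 4, 1], [2, 4]]
```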

I know argsort can find the indices of the top k elements in sorted order, but I am not sure how to filter out the zeros.
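To see why the zeros need special handling, here is a minimal sketch of the plain argsort approach (negating x because np.argsort sorts in ascending order). On the third row, which has only two nonzero values, the third returned index points at a zero. Which of the tied zeros it picks is not guaranteed, because the default sort kind is not stable.

```python
import numpy as np

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
k = 3
# Descending order via negation, then keep the first k columns of each row
naive = np.argsort(-x, axis=-1)[:, :k]
print(naive)
# The last entry of the third row refers to a zero-valued element
print(x[2, naive[2]])
```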


Solution

  • You can use numpy.argsort on -x to get each row's indices in descending order of value (argsort itself sorts ascending), then use numpy.take_along_axis with those indices to get the values in the same sorted order. Since you only want the top k = 3, set every sorted value after the third column to zero so it is discarded together with the real zeros. Finally, keep only the indices whose sorted value is nonzero; boolean indexing returns them as one flattened array, row by row.

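    import numpy as np
    x = np.array([[4.,0.,2.,0.,8.],[1.,3.,0.,9.,5.],[0.,0.,4.,0.,1.]])
    k = 3
    # Indices that sort each row in descending order (argsort is ascending, so negate)
    idx_srt = np.argsort(-x, axis=-1)
    # Values of each row rearranged into that descending order
    val_srt = np.take_along_axis(x, idx_srt, axis=-1)
    # Zero out everything past the top k so it is filtered out with the real zeros
    val_srt[:, k:] = 0
    # Keep indices whose sorted value is nonzero (result is flattened row by row)
    res = idx_srt[val_srt != 0]
    print(res)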
    

    [4 0 2 3 4 1 2 4]
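If you want the per-row list of arrays rather than one flattened array, a minimal sketch is to reuse the same argsort/take_along_axis steps, slice the first k sorted columns, and filter each row separately so rows keep their own (possibly shorter) length. `top_k_nonzero_per_row` is a hypothetical helper name, and like the answer above it assumes non-negative input. Passing its output to np.concatenate gives back the flattened result.

```python
import numpy as np

def top_k_nonzero_per_row(x, k):
    """Per row: indices of the k largest entries, largest first, zeros dropped."""
    # Descending-order indices and the matching values, as in the answer above
    idx_srt = np.argsort(-x, axis=-1)
    val_srt = np.take_along_axis(x, idx_srt, axis=-1)
    # Only the first k sorted positions are candidates (no in-place zeroing needed)
    idx_top = idx_srt[:, :k]
    val_top = val_srt[:, :k]
    # Filter row by row so each row keeps only its nonzero candidates
    return [row_idx[row_val != 0] for row_idx, row_val in zip(idx_top, val_top)]

x = np.array([[4., 0., 2., 0., 8.],
              [1., 3., 0., 9., 5.],
              [0., 0., 4., 0., 1.]])
res = top_k_nonzero_per_row(x, 3)
print([r.tolist() for r in res])  # [[4, 0, 2], [3, 4, 1], [2, 4]]
print(np.concatenate(res))        # [4 0 2 3 4 1 2 4]
```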