I would like to apply a sort operation row by row, keeping only the values above a given threshold.
For this, I see I can use a masked array to apply the threshold.
However, argsort still considers the masked values (those below the threshold) and replaces them with a fill_value.
I simply don't want any result for a value that has been masked; I would like a NaN in its place.
import numpy as np

a = np.array([[0.522235, 0.128270, 0.708973],
              [0.994557, 0.844426, 0.366608],
              [0.986669, 0.143659, 0.395891],
              [0.291339, 0.421843, 0.278869],
              [0.250303, 0.861475, 0.904534],
              [0.973436, 0.360466, 0.751913]])
threshold = 0.5
m_a = np.ma.masked_less_equal(a, threshold)
argsorted = m_a.argsort(-1)
This gives me:
array([[0, 2, 1],
[1, 0, 2],
[0, 1, 2],
[0, 1, 2],
[1, 2, 0],
[2, 0, 1]])
But I would like to get:
array([[0, NaN, 1],
[1, 0, NaN],
[0, NaN, NaN],
[NaN, NaN, NaN],
[NaN, 0, 1],
[ 1, NaN, 0]])
Any idea how to get this result?
Thanks for your help! Bests,
We can add one more argsort for an easier way to get to our desired output -
sidx = argsorted.argsort(1)
mask = sidx >= (a.shape[1]-m_a.mask.sum(1,keepdims=True))
out = np.where(mask,np.nan,sidx)
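To make the steps above easier to follow, here is a self-contained sketch on the question's array. It is a minimal illustration, not a definitive implementation. One assumption of mine: it reads the mask through np.ma.getmaskarray rather than m_a.mask, so that it still works when no value is masked and the mask may not be a full array. The name n_masked is only for illustration.

```python
import numpy as np

a = np.array([[0.522235, 0.128270, 0.708973],
              [0.994557, 0.844426, 0.366608],
              [0.986669, 0.143659, 0.395891],
              [0.291339, 0.421843, 0.278869],
              [0.250303, 0.861475, 0.904534],
              [0.973436, 0.360466, 0.751913]])
threshold = 0.5

m_a = np.ma.masked_less_equal(a, threshold)
argsorted = m_a.argsort(-1)   # masked entries are placed last in each row

# Second argsort: rank of each element within its row
# (the inverse of the permutation returned by the first argsort).
sidx = argsorted.argsort(1)

# Assumption: getmaskarray gives a full boolean array even if nothing is masked.
n_masked = np.ma.getmaskarray(m_a).sum(1, keepdims=True)

# Ranks at or beyond the number of valid entries belong to masked values.
mask = sidx >= (a.shape[1] - n_masked)
out = np.where(mask, np.nan, sidx)
print(out)
```

Because NaN is a floating-point value, out is a float array even though the ranks themselves are integers.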
We can also start from scratch to avoid masked-arrays -
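def thresholded_argsort(a, threshold):
    m = a < threshold                # entries to discard (use <= to match masked_less_equal)
    ac = a.copy()
    ac[m] = ac.max() + 1             # push discarded entries to the end of each row's sort order
    sidx = ac.argsort(1).argsort(1)  # rank of each element within its row
    mask = sidx >= (ac.shape[1] - m.sum(1, keepdims=True))
    return np.where(mask, np.nan, sidx)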
Sample run -
In [46]: a
Out[46]:
array([[0.522235, 0.12827 , 0.708973],
[0.994557, 0.844426, 0.366608],
[0.986669, 0.143659, 0.395891],
[0.291339, 0.421843, 0.278869],
[0.250303, 0.861475, 0.904534],
[0.973436, 0.360466, 0.751913]])
In [47]: thresholded_argsort(a, threshold=0.5)
Out[47]:
array([[ 0., nan, 1.],
[ 1., 0., nan],
[ 0., nan, nan],
[nan, nan, nan],
[nan, 0., 1.],
[ 1., nan, 0.]])
Note: We can avoid the additional argsort with array-assignment for performance using argsort_unique. So, for 2D arrays along the second axis, it would be -
def argsort_unique2D(idx):
    # idx: each row must be a permutation of 0..n-1 (e.g. an argsort result)
    m, n = idx.shape
    idx_out = np.empty((m, n), dtype=int)
    np.put_along_axis(idx_out, idx, np.arange(n), axis=1)
    return idx_out
So, argsorted.argsort(1) could be replaced by argsort_unique2D(argsorted), and ac.argsort(1).argsort(1) by argsort_unique2D(ac.argsort(1)), in the earlier posted solutions.
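As a quick sanity check of that replacement, the sketch below compares argsort_unique2D against the double argsort on random data and plugs it into the from-scratch solution. The function name thresholded_argsort_fast, the random test array, and the seed are my own choices for illustration; whether this is actually faster will depend on array size, so treat the performance claim as something to measure.

```python
import numpy as np

def argsort_unique2D(idx):
    # Invert each row's permutation: idx_out[i, idx[i, j]] = j
    m, n = idx.shape
    idx_out = np.empty((m, n), dtype=int)
    np.put_along_axis(idx_out, idx, np.arange(n), axis=1)
    return idx_out

def thresholded_argsort_fast(a, threshold):  # hypothetical name, for illustration
    m = a < threshold
    ac = a.copy()
    ac[m] = ac.max() + 1
    sidx = argsort_unique2D(ac.argsort(1))   # replaces ac.argsort(1).argsort(1)
    mask = sidx >= (ac.shape[1] - m.sum(1, keepdims=True))
    return np.where(mask, np.nan, sidx)

# Check that the array-assignment version matches the second argsort.
rng = np.random.default_rng(0)   # arbitrary seed
x = rng.random((5, 4))
order = x.argsort(1)
same_ranks = np.array_equal(argsort_unique2D(order), order.argsort(1))
print(same_ranks)
```

Note that argsort_unique2D relies on every row of its input being a full permutation of the column indices, which holds for the output of argsort.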