
Efficient masked argsort in Numpy


I have a numpy array such as this one:

import numpy as np

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

I would like to argsort each row, but with the entries where arr <= 0 masked out. The output should be:

array([[0, 1, 2],
       [0, 2],       # (Note that the indices are still relative to original un-masked array)
       []])

However, the output I get using np.ma.argsort() is:

array([[0, 1, 2],
       [0, 2, 1],
       [0, 1, 2]])
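The np.ma.argsort output above is closer to the goal than it looks: with the default endwith=True, masked entries are filled with a value that sorts after the valid ones, so the first k indices of each row are exactly the sorted valid indices, where k is that row's number of unmasked entries. A minimal sketch, assuming the mask is built with np.ma.masked_less_equal and that no valid entry ties with the fill value (for integer arrays, the dtype's maximum):

```python
import numpy as np

arr = np.array([
    [1, 2, 3],
    [4, -5, 6],
    [-1, -1, -1]
])

# Mask (hide) every entry <= 0
marr = np.ma.masked_less_equal(arr, 0)

# One sort over the whole array; masked entries end up at the end of each row
order = np.ma.argsort(marr, axis=1)

# Number of unmasked entries in each row
counts = marr.count(axis=1)

# Keep only the leading, valid part of each row
result = [row[:n] for row, n in zip(order, counts)]
print(result)
```

The Python loop here only slices; the sorting itself happens in a single call over the whole array. Passing axis explicitly matters, since older NumPy versions warn about the default axis of the masked argsort on 2-D input.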

The approach needs to be very efficient because the real array has millions of columns. I suspect it needs a combination of a few NumPy operations, but I don't know which ones.


Solution

    The np.where approach:

    Input array

    import numpy as np

    arr = np.array([
        [1, 2, 3],
        [4, -5, 6],
        [-1, -1, -1]
    ])
    

    Mask of valid elements

    mask = arr > 0
    

    Preallocate result as an object array to hold variable-length indices

    result = np.empty(arr.shape[0], dtype=object)
    

    Masked argsort for each row (one Python-level loop iteration per row)

    for i in range(arr.shape[0]):
        valid_indices = np.where(mask[i])[0]  # Column indices where mask is True (valid entries)
        result[i] = valid_indices[np.argsort(arr[i, valid_indices])]  # Order those indices by the values they point to
    

    Output:

    [array([0, 1, 2]) array([0, 2]) array([], dtype=int64)]
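The step that keeps the indices relative to the original row is the last line of the loop: np.argsort sorts only the valid values, so its output indexes into valid_indices rather than into the row, and indexing valid_indices with it maps those positions back to original columns. A small illustration on a hypothetical row (not from the question) where the order actually changes:

```python
import numpy as np

row = np.array([5, -1, 3, 2])                  # hypothetical example row

valid_indices = np.where(row > 0)[0]           # original columns of valid entries: [0, 2, 3]
local_order = np.argsort(row[valid_indices])   # positions within the subset [5, 3, 2]: [2, 1, 0]
sorted_indices = valid_indices[local_order]    # mapped back to original columns: [3, 2, 0]
print(sorted_indices)
```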
    

    The np.flatnonzero approach:

    A variant that uses np.flatnonzero instead of np.where. It still loops over the rows in Python, so it is not fully vectorised:

    def optimized_masked_argsort(arr, mask):
        result = np.empty(arr.shape[0], dtype=object)
        for i in range(arr.shape[0]):
            row = arr[i]
            valid_indices = np.flatnonzero(mask[i])  # Same indices as np.where(mask[i])[0] for a 1-D mask
            valid_values = row[valid_indices]
            sorted_order = np.argsort(valid_values)
            result[i] = valid_indices[sorted_order]
        return result
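To move the per-row work into whole-array calls, one option is to swap in np.lexsort (a different technique from the two above): sort every row once with ~mask as the primary key, so invalid entries go last and valid ones are ordered by value, then keep the first count entries of each row. The function name masked_argsort_vectorized and the (flat, offsets) return format are hypothetical choices for this sketch, not an established API:

```python
import numpy as np

def masked_argsort_vectorized(arr, mask):
    """Hypothetical helper: per-row argsort of arr restricted to mask.

    Returns (flat, offsets): flat holds the sorted valid column indices of
    all rows back to back; row i is flat[offsets[i]:offsets[i + 1]].
    """
    # np.lexsort uses the LAST key as the primary key: entries with
    # ~mask == True (invalid) sort after valid ones, ties broken by value.
    order = np.lexsort((arr, ~mask), axis=-1)
    counts = mask.sum(axis=1)
    # Boolean selector for the leading counts[i] positions of each row
    keep = np.arange(arr.shape[1]) < counts[:, None]
    flat = order[keep]  # row-major, so rows stay contiguous
    offsets = np.concatenate(([0], np.cumsum(counts)))
    return flat, offsets

arr = np.array([[1, 2, 3], [4, -5, 6], [-1, -1, -1]])
flat, offsets = masked_argsort_vectorized(arr, arr > 0)
rows = np.split(flat, offsets[1:-1])  # per-row arrays, only if a list is needed
print(rows)
```

The flat-plus-offsets layout avoids creating one Python object per row, which may matter with many rows. Note that order and keep are full-size temporaries, so peak memory is several times that of arr. For float data, np.argsort(np.where(mask, arr, np.inf), axis=1) is a cheaper single-key alternative, assuming no valid entry is itself +inf.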
    

    Comparison:

    Timings for given example:
    np.where Time: 0.000034 seconds
    np.flatnonzero Time: 0.000017 seconds
    
    Timings for larger array (1000 rows):
    np.where Time: 0.001856 seconds
    np.flatnonzero Time: 0.001754 seconds
    

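    At 1000 rows the two approaches are within about 6% of each other, so swapping np.where for np.flatnonzero makes little difference. I tried a few other methods, but they were less efficient.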
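The timings above do not state the number of columns or how many runs were measured. A sketch of how such a comparison could be set up with timeit; where_argsort and flatnonzero_argsort are hypothetical names restating the two loops above, and the array shape, value range, and repeat counts are arbitrary placeholders, so no particular results are implied:

```python
import functools
import timeit

import numpy as np

def where_argsort(arr, mask):
    # Restatement of the np.where loop
    result = np.empty(arr.shape[0], dtype=object)
    for i in range(arr.shape[0]):
        valid = np.where(mask[i])[0]
        result[i] = valid[np.argsort(arr[i, valid])]
    return result

def flatnonzero_argsort(arr, mask):
    # Restatement of the np.flatnonzero loop
    result = np.empty(arr.shape[0], dtype=object)
    for i in range(arr.shape[0]):
        valid = np.flatnonzero(mask[i])
        result[i] = valid[np.argsort(arr[i, valid])]
    return result

rng = np.random.default_rng(0)
arr = rng.integers(-10, 10, size=(1000, 1000))  # placeholder shape
mask = arr > 0

for name, fn in [("np.where", where_argsort), ("np.flatnonzero", flatnonzero_argsort)]:
    # Best of 3 repeats, 5 calls each, reported per call
    best = min(timeit.repeat(functools.partial(fn, arr, mask), number=5, repeat=3)) / 5
    print(f"{name}: {best:.6f} s per call")
```

Taking the minimum over repeats reduces noise from other processes, and sizing the array like the real data (millions of columns) gives more representative numbers than the 3×3 example.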