
Applying numpy partition to a multi-dimensional array

I need to find the k-th smallest element within an np.array. In a simple case, you would probably use np.partition.

import numpy as np

a = np.array([7, 4, 1, 0])
kth = 1
p = np.partition(a, kth)
print(f"Partitioned array: {p}")
print(f"kth's smallest element: {p[kth]}")

Partitioned array: [0 1 4 7]
kth's smallest element: 1

In my real use case, I need to apply the same technique to a multi-dimensional np.array. Let's take a 4-dim array as an example. The difficulty I am facing is that I need to apply a different kth to each row of that array.
(Note: array_4d and kths come from earlier operations.)

Here's the setup:

array_4d = np.array([
    [
        [[4, 1, np.nan, 20, 11, 12]],
        [[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]],
        [[33, 4, 55, 26, 17, 18]],
    ],
    [
        [[7, 8, 9, np.nan, 11, 12]],
        [[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan]],
        [[13, 14, 15, 16, 17, 18]],
    ],
])

# One kth per row; the last two dimensions are always 1
kths = np.array([
    [[[1]], [[2]], [[0]]],
    [[[0]], [[2]], [[1]]],
])

print("4D array:")
print(array_4d)
print(f"Shape: {array_4d.shape}")

print("kths array:")
print(kths)
print(f"Shape: {kths.shape}")

4D array:
[[[[ 4.  1. nan 20. 11. 12.]]

  [[nan nan nan nan nan nan]]

  [[33.  4. 55. 26. 17. 18.]]]

 [[[ 7.  8.  9. nan 11. 12.]]

  [[nan nan nan nan nan nan]]

  [[13. 14. 15. 16. 17. 18.]]]]
Shape: (2, 3, 1, 6)

kths array:
[[[[1]]

  [[2]]

  [[0]]]

 [[[0]]

  [[2]]

  [[1]]]]
Shape: (2, 3, 1, 1)

I need to apply the different kths (1, 2, 0, 0, 2, 1) to the respective rows of the 4D array and, for each row, find the element that would sit at index kth if that row were sorted.

The expected result should probably look like this:

array([[[[ 4.]],

        [[nan]],

        [[ 4.]]],

       [[[ 7.]],

        [[nan]],

        [[14.]]]])

EDIT: I am looking for a generalized solution. The input array could have any shape, with the exception that the second-to-last dimension (axis=-2) is always 1. For the kths array, the last two dimensions are always 1.


  • I would do the following:

    1. For simplified indexing, flatten everything except the axis of interest (if I understand you correctly, this is always the last axis), which produces an N×n-shaped array with N rows resulting from flattening and n columns representing the values along the axis of interest.
    2. Sort the values along the axis of interest.
    3. For the i-th of the N rows, pick the resulting element at index kths.ravel()[i], where kths.ravel() should be the N-element 1-d view of the given kths array.
    4. Reshape the result to the original shape minus the axis of interest.

    This could look as follows:

    import numpy as np
    # Given example
    nan = np.nan
    a = np.array([[[[ 4.,  1., nan, 20., 11., 12.]],
                   [[nan, nan, nan, nan, nan, nan]],
                   [[33.,  4., 55., 26., 17., 18.]]],
                  [[[ 7.,  8.,  9., nan, 11., 12.]],
                   [[nan, nan, nan, nan, nan, nan]],
                   [[13., 14., 15., 16., 17., 18.]]]])
    kths = np.asarray([1, 2, 0, 0, 2, 1]).reshape(a.shape[:-1])
    # Proposed approach
    sorted_a = np.sort(a.reshape(-1, a.shape[-1]), axis=-1)    # Step 1+2
    result = sorted_a[np.arange(len(sorted_a)), kths.ravel()]  # Step 3
    result = result.reshape(kths.shape)                        # Step 4

    Which produces:

    [[[ 4.]
      [nan]
      [ 4.]]
     [[ 7.]
      [nan]
      [14.]]]

    If you prefer, for your given example, a 2×3×1×1 (i.e. 4-d) result rather than the current 2×3×1 (i.e. 3-d) result, you can replace the last reshaping operation by

    result = result.reshape(*kths.shape, 1)
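
If kths already has the same number of dimensions as the input, as in the question's setup with shape (2, 3, 1, 1), the flatten and reshape steps can probably be skipped with np.take_along_axis. The following is only a sketch under that assumption; the helper name kth_smallest_along_last_axis is made up for illustration and is not part of NumPy.

```python
import numpy as np

def kth_smallest_along_last_axis(a, kths):
    """Pick, per row, the element at sorted index kths along the last axis.

    Hypothetical helper: assumes integer `kths` with the same number of
    dimensions as `a` and size 1 along the last axis.
    """
    sorted_a = np.sort(a, axis=-1)  # np.sort places NaNs at the end
    return np.take_along_axis(sorted_a, kths, axis=-1)

nan = np.nan
array_4d = np.array([[[[4., 1., nan, 20., 11., 12.]],
                      [[nan] * 6],
                      [[33., 4., 55., 26., 17., 18.]]],
                     [[[7., 8., 9., nan, 11., 12.]],
                      [[nan] * 6],
                      [[13., 14., 15., 16., 17., 18.]]]])
kths = np.array([1, 2, 0, 0, 2, 1]).reshape(2, 3, 1, 1)

result = kth_smallest_along_last_axis(array_4d, kths)
print(result.shape)    # same shape as kths
print(result.ravel())  # values: 4, nan, 4, 7, nan, 14
```

The result keeps the shape of kths, so no final reshape is needed in this variant.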

    Abandoned alternative

    As an alternative, I also tried producing sorted_a via

    sorted_a = np.partition(a.reshape(-1, a.shape[-1]), kth=range(np.max(kths) + 1), axis=-1)

    This was based on the idea that, at most, you need the entries up to the largest index in kths to be in sorted order. This, however, did not produce a speedup for me and, in fact, even slowed down the calculation significantly.
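
For anyone who wants to retry the partition variant, here is a small sketch that checks it picks the same elements as the full sort. The random test data, the seed, and all variable names are assumptions made for illustration; it says nothing about speed, which would need a separate measurement (for example with timeit) on realistic input sizes.

```python
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((2, 3, 1, 6))
a[0, 1] = np.nan  # include an all-NaN row, as in the question
kths = rng.integers(0, a.shape[-1], size=a.shape[:-1])

flat = a.reshape(-1, a.shape[-1])
rows = np.arange(len(flat))

# Full sort of every row
via_sort = np.sort(flat, axis=-1)[rows, kths.ravel()]

# Partition so that indices 0..max(kths) hold their sorted values
kth_range = np.arange(kths.max() + 1)
via_partition = np.partition(flat, kth=kth_range, axis=-1)[rows, kths.ravel()]

# Compare element-wise, treating NaN as equal to NaN
both_nan = np.isnan(via_sort) & np.isnan(via_partition)
same = bool(np.all((via_sort == via_partition) | both_nan))
print(same)
```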