I need to find the k-th smallest element within a np.array. In a simple case you would probably use np.partition.
import numpy as np
a = np.array([7, 4, 1, 0])
kth = 1
p = np.partition(a, kth)
print(f"Partitioned array: {p}")
print(f"kth's smallest element: {p[kth]}")
Partitioned array: [0 1 4 7]
kth's smallest element: 1
In my real use case, I need to apply the same technique to a multi-dimensional np.array. Let's take a 4-dim array as an example. The difficulty I am facing is that I need to apply a different kth to each row of that array.
(Note: array_4d and kths come from earlier operations.)
Here's the setup:
array_4d = np.array(
[
[
[
[4, 1, np.nan, 20, 11, 12],
],
[
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
],
[
[33, 4, 55, 26, 17, 18],
],
],
[
[
[7, 8, 9, np.nan, 11, 12],
],
[
[np.nan, np.nan, np.nan, np.nan, np.nan, np.nan],
],
[
[13, 14, 15, 16, 17, 18],
],
],
]
)
kths = np.array(
[
[
[[1]],
[[2]],
[[0]],
],
[
[[0]],
[[2]],
[[1]],
],
]
)
print("4D array:")
print(array_4d)
print(f"Shape: {array_4d.shape}")
print("kths array:")
print(kths)
print(f"Shape: {kths.shape}")
4D array:
[[[[ 4. 1. nan 20. 11. 12.]]
[[nan nan nan nan nan nan]]
[[33. 4. 55. 26. 17. 18.]]]
[[[ 7. 8. 9. nan 11. 12.]]
[[nan nan nan nan nan nan]]
[[13. 14. 15. 16. 17. 18.]]]]
Shape: (2, 3, 1, 6)
kths array:
[[[[1]]
[[2]]
[[0]]]
[[[0]]
[[2]]
[[1]]]]
Shape: (2, 3, 1, 1)
I need to apply the different kths (1, 2, 0, 0, 2, 1) to the respective rows of the 4D array and find, for each row, the smallest element at position kth.
The expected result should probably look like this:
array([[[[ 4.]],
[[nan]],
[[ 4.]]],
[[[ 7.]],
[[nan]],
[[14.]]]])
EDIT: I am looking for a generalized solution. The input array could have any shape, except that the second-to-last dimension (axis=-2) is always 1. For the kths array, the last two dimensions are always 1.
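I would do the following:
1. Reshape the given array to 2-d, i.e. N×M, where M is the size of the last dimension.
2. Sort each row.
3. From each sorted row i, pick the element at index kths.ravel()[i], where kths.ravel() is the N-element 1-d view of the given kths array.
4. Reshape the result to the shape of kths.
This could look as follows: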
import numpy as np
# Given example
nan = np.nan
a = np.array([[[[ 4., 1., nan, 20., 11., 12.]],
[[nan, nan, nan, nan, nan, nan]],
[[33., 4., 55., 26., 17., 18.]]],
[[[ 7., 8., 9., nan, 11., 12.]],
[[nan, nan, nan, nan, nan, nan]],
[[13., 14., 15., 16., 17., 18.]]]])
kths = np.asarray([1, 2, 0, 0, 2, 1]).reshape(a.shape[:-1])
# Proposed approach
sorted_a = np.sort(a.reshape(-1, a.shape[-1]), axis=-1) # Step 1+2
result = sorted_a[np.arange(len(sorted_a)), kths.ravel()] # Step 3
result = result.reshape(kths.shape) # Step 4
print(result)
Which produces:
[[[ 4.]
[nan]
[ 4.]]
[[ 7.]
[nan]
[14.]]]
If you prefer, for your given example, a 2×3×1×1 (i.e. 4-d) result rather than the current 2×3×1 (i.e. 3-d) result, you can replace the last reshaping operation by
result = result.reshape(*kths.shape, 1)
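For inputs of arbitrary shape, as requested in the question's edit, the same sort-then-pick idea can probably also be written without flattening. The following is only a sketch of an alternative, not the approach above: it assumes kths has the shape from the question (the leading dimensions of the array followed by 1, 1), and the function name kth_smallest_along_last_axis is made up for illustration. It relies on np.take_along_axis, which picks one index per row along a given axis.

```python
import numpy as np

def kth_smallest_along_last_axis(arr, kths):
    """Pick, for every row along the last axis, the element at sorted position kths.

    Assumption: arr has shape (..., 1, M) and kths has shape (..., 1, 1)
    with an integer dtype.
    """
    sorted_arr = np.sort(arr, axis=-1)  # np.sort places NaNs at the end of each row
    return np.take_along_axis(sorted_arr, kths, axis=-1)

nan = np.nan
array_4d = np.array([[[[ 4.,  1., nan, 20., 11., 12.]],
                      [[nan, nan, nan, nan, nan, nan]],
                      [[33.,  4., 55., 26., 17., 18.]]],
                     [[[ 7.,  8.,  9., nan, 11., 12.]],
                      [[nan, nan, nan, nan, nan, nan]],
                      [[13., 14., 15., 16., 17., 18.]]]])
kths = np.array([1, 2, 0, 0, 2, 1]).reshape(2, 3, 1, 1)

result = kth_smallest_along_last_axis(array_4d, kths)
print(result.shape)
print(result.ravel())
```

Because kths keeps its trailing dimensions of size 1, the result comes back with the same shape as kths, so no final reshape is needed in this variant.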
As an alternative, I also tried producing sorted_a via
sorted_a = np.partition(a.reshape(-1, a.shape[-1]), kth=range(np.max(kths) + 1), axis=-1)
This was based on the idea that you only need the entries up to the largest value in kths to be in sorted order. This, however, did not produce a speedup for me and, in fact, even slowed down the calculation significantly.
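To check that the partition-based variant picks the same elements as the full sort, and to compare the two on your own data, something like the following sketch could be used. The random test data, its shape, and the helper names via_sort and via_partition are assumptions for illustration. Timings will depend on the array sizes and the NumPy version, so no numbers are claimed here.

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
a = rng.random((100, 100, 1, 50))              # hypothetical test data
kths = rng.integers(0, 5, size=a.shape[:-1])   # one k per row, shape (100, 100, 1)

flat = a.reshape(-1, a.shape[-1])              # N x M view of the data
rows = np.arange(len(flat))
kmax = int(kths.max())

def via_sort():
    # Fully sort every row, then pick one element per row
    return np.sort(flat, axis=-1)[rows, kths.ravel()]

def via_partition():
    # Only positions 0..kmax are guaranteed to hold their sorted values
    part = np.partition(flat, kth=np.arange(kmax + 1), axis=-1)
    return part[rows, kths.ravel()]

same = np.array_equal(via_sort(), via_partition())
print("same result:", same)
print("sort     :", timeit.timeit(via_sort, number=5))
print("partition:", timeit.timeit(via_partition, number=5))
```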