I have integer arrays like the following:
```python
import numpy as np
seed_idx = np.asarray([[0, 1],
                       [1, 2],
                       [2, 3],
                       [3, 4]], dtype=np.int_)
target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]], dtype=np.int_)
```
For each row of `target_idx`, I want to select the elements whose indices are *not* the ones in the corresponding row of `seed_idx`. The resulting array should thus be:
```python
[[4, 1, 8],
 [9, 2, 4],
 [1, 0, 9],
 [7, 1, 2]]
```
In other words, I want to do something similar to `np.take_along_axis(target_idx, seed_idx, axis=1)`, but excluding the indices instead of keeping them.
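For comparison, `np.take_along_axis` returns exactly the complement of what is wanted here: for each row it keeps the elements at the positions listed in `seed_idx`. A minimal sketch on the example data:

```python
import numpy as np

seed_idx = np.asarray([[0, 1], [1, 2], [2, 3], [3, 4]])
target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]])

# take_along_axis picks, row by row, the elements at the positions in seed_idx
kept = np.take_along_axis(target_idx, seed_idx, axis=1)
print(kept)
# [[2 9]
#  [7 6]
#  [0 4]
#  [3 8]]
```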
What is the most elegant way to do this? I find it surprisingly hard to come up with something neat.
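As a baseline to check vectorized answers against, a straightforward but non-vectorized way to get the desired result is to drop the seed positions from each row with `np.delete` in a Python loop. This is a sketch for reference only, not part of the original post; it is slower for large arrays because of the per-row loop.

```python
import numpy as np

seed_idx = np.asarray([[0, 1], [1, 2], [2, 3], [3, 4]])
target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]])

# Loop over rows, removing the positions listed in the matching row of seed_idx;
# np.delete returns a new array and leaves target_idx untouched.
result = np.array([np.delete(row, seeds) for row, seeds in zip(target_idx, seed_idx)])
print(result)
# [[4 1 8]
#  [9 2 4]
#  [1 0 9]
#  [7 1 2]]
```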
You can overwrite the values you don't want with a sentinel such as `-1` using `np.put_along_axis`, and then select the remaining ones:
```python
>>> out = target_idx.copy()  # put_along_axis modifies its argument in place
>>> np.put_along_axis(out, seed_idx, -1, axis=1)
>>> out[out != -1].reshape(len(out), -1)
array([[4, 1, 8],
       [9, 2, 4],
       [1, 0, 9],
       [7, 1, 2]])
```
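To see why this works, the sketch below runs the same two steps on a copy of the example data and prints the intermediate array. Boolean indexing returns the surviving values flattened in row-major order, so reshaping to one row per input row recovers the rectangular result.

```python
import numpy as np

seed_idx = np.asarray([[0, 1], [1, 2], [2, 3], [3, 4]])
target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]])

marked = target_idx.copy()  # work on a copy, since put_along_axis writes in place
np.put_along_axis(marked, seed_idx, -1, axis=1)  # overwrite seed positions with -1
print(marked)
# [[-1 -1  4  1  8]
#  [ 9 -1 -1  2  4]
#  [ 1  0 -1 -1  9]
#  [ 7  1  2 -1 -1]]

# marked != -1 selects 3 values per row; reshape restores the 4 rows
result = marked[marked != -1].reshape(len(marked), -1)
print(result)
```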
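If `-1` can occur in your data, use `target_idx.min() - 1` as the sentinel instead. Note also that the final `reshape` assumes each row of `seed_idx` holds distinct indices, so that every row loses the same number of elements.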
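A sentinel-free variant, offered here as a suggestion rather than part of the original answer: apply the same `np.put_along_axis` call to a boolean mask instead of to the data, then index with the mask. No data value can collide with a marker, the input is never modified, and it works for any dtype. `drop_along_rows` is a hypothetical helper name, and the sketch only handles 2-D arrays with the indices applied along axis 1; like the reshape above, it assumes distinct indices within each row of `seed_idx`.

```python
import numpy as np

def drop_along_rows(arr, indices):
    """Hypothetical helper: remove, from each row of 2-D `arr`, the columns
    listed in the matching row of `indices`."""
    # Start with every position marked as "keep"
    keep = np.ones(arr.shape, dtype=bool)
    # Flip the positions listed in each row of `indices` to False
    np.put_along_axis(keep, indices, False, axis=1)
    # Boolean indexing yields the kept values in row-major order;
    # reshape back to one row per input row
    return arr[keep].reshape(arr.shape[0], -1)

seed_idx = np.asarray([[0, 1], [1, 2], [2, 3], [3, 4]])
target_idx = np.asarray([[2, 9, 4, 1, 8],
                         [9, 7, 6, 2, 4],
                         [1, 0, 0, 4, 9],
                         [7, 1, 2, 3, 8]])

result = drop_along_rows(target_idx, seed_idx)
print(result)
# [[4 1 8]
#  [9 2 4]
#  [1 0 9]
#  [7 1 2]]
```

Because the mask is separate from the data, an array that legitimately contains `-1` is handled the same way as any other.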