
Select n'th element along m'th axis on numpy array


My goal is to use a NumPy array of integers to select one element from each corresponding row of a 2D array. The 2D array is monotonically increasing along each row (axis 1), and for each row I am trying to select the first element that equals or exceeds some threshold.

I have a 2D array of shape N x 10 whose values are monotonically increasing along the 2nd axis, e.g.:

    sizes = np.array([[1, 3, 6, 6, 6, 7, 8, 8, 10, 10],
                      [2, 3, 3, 7, 7, 7, 9, 9, 10, 11],
                      [2, 3, 3, 5, 5, 6, 9, 9, 10, 11],
                      [2, 3, 3, 9, 9, 9, 9, 9, 10, 11]])

The goal is to find, for each row, the index of the first element that equals or exceeds 5, which I do with the following code:

    threshold_indeces = np.argmax(sizes >= 5, axis=1)

Now I want to extract the corresponding value from every row. Intuitively, I would run this:

    values = sizes[:, threshold_indeces]

and expect an output of [6, 7, 5, 9]. However, I get a 2D array of shape N x N, which in this case is:

    array([[6, 6, 6, 6],
           [3, 7, 7, 7],
           [3, 5, 5, 5],
           [3, 9, 9, 9]])
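A minimal reproduction with the example array shows where the N x N shape comes from. It assumes nothing beyond the question's own data; the calls to `np.diagonal` and `.shape` are only there to make the structure of the result visible.

```python
import numpy as np

sizes = np.array([[1, 3, 6, 6, 6, 7, 8, 8, 10, 10],
                  [2, 3, 3, 7, 7, 7, 9, 9, 10, 11],
                  [2, 3, 3, 5, 5, 6, 9, 9, 10, 11],
                  [2, 3, 3, 9, 9, 9, 9, 9, 10, 11]])

# argmax on a boolean mask returns the position of the first True in each row.
threshold_indeces = np.argmax(sizes >= 5, axis=1)
print(threshold_indeces)  # [2 3 3 3]

# A slice on axis 0 combined with an integer array on axis 1 takes every row
# for every listed column, so the result has shape (N, len(threshold_indeces)).
values = sizes[:, threshold_indeces]
print(values.shape)  # (4, 4)

# values[i, j] == sizes[i, threshold_indeces[j]], so the wanted values
# (row i paired with its own column) sit on the diagonal.
print(np.diagonal(values))  # [6 7 5 9]
```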

I can see my expected output repeated in columns 1, 2 and 3, but when I run this with millions of rows I get a memory allocation error.
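The memory error follows directly from the N x N shape. The sketch below does the arithmetic, assuming 64-bit integers (8 bytes per element); the helper names are hypothetical and only serve to compare the two result sizes.

```python
import numpy as np

def fancy_index_result_bytes(n_rows, itemsize):
    """Hypothetical helper: bytes for sizes[:, idx] when idx has one entry per row (N x N result)."""
    return n_rows * n_rows * itemsize

def row_wise_result_bytes(n_rows, itemsize):
    """Hypothetical helper: bytes for a 1D result holding one value per row."""
    return n_rows * itemsize

itemsize = np.dtype(np.int64).itemsize  # 8 bytes; assumes int64 data
n = 1_000_000
print(fancy_index_result_bytes(n, itemsize))  # 8000000000000 (about 8 TB)
print(row_wise_result_bytes(n, itemsize))     # 8000000 (about 8 MB)
```

The quadratic growth, not the data itself, is what exhausts memory.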

What am I doing wrong and how can I just get the 1D output of the values?


Solution

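  • You are essentially telling NumPy that you want, from every row, the values in all of the columns listed in threshold_indeces (by the way, it's spelled indices). What you need is to also tell NumPy which row goes with each column. You can do this by indexing the rows with np.arange: when two integer arrays of the same length are used as indices, NumPy pairs them element-wise, so row i is combined with column threshold_indeces[i] and the result is a 1D array with one value per row.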

    values = sizes[np.arange(threshold_indeces.size), threshold_indeces]
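A self-contained version of the fix, run on the question's example data, is sketched below. It also shows `np.take_along_axis` as an alternative way to express the same row-wise selection; that function is a different technique from the answer's `np.arange` indexing and is included only for comparison.

```python
import numpy as np

sizes = np.array([[1, 3, 6, 6, 6, 7, 8, 8, 10, 10],
                  [2, 3, 3, 7, 7, 7, 9, 9, 10, 11],
                  [2, 3, 3, 5, 5, 6, 9, 9, 10, 11],
                  [2, 3, 3, 9, 9, 9, 9, 9, 10, 11]])
threshold = 5

# Column index of the first element >= threshold in each row.
threshold_indeces = np.argmax(sizes >= threshold, axis=1)

# Pair row i with column threshold_indeces[i]: two integer index arrays of
# the same length select one element per pair, giving a 1D result of length N.
rows = np.arange(threshold_indeces.size)
values = sizes[rows, threshold_indeces]

# Alternative: np.take_along_axis expects indices with the same number of
# dimensions as the array, so add a trailing axis and then drop it again.
values_alt = np.take_along_axis(sizes, threshold_indeces[:, None], axis=1)[:, 0]

print(values)      # [6 7 5 9]
print(values_alt)  # [6 7 5 9]
```

Both forms allocate only N values for the result, so they scale to millions of rows where `sizes[:, threshold_indeces]` does not.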
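One caveat with the `np.argmax` approach: if a row never reaches the threshold, its mask is all False and `argmax` returns 0, so the row's first element would be selected silently. The sketch below uses small hypothetical data and an arbitrary sentinel of -1 (an assumption; pick whatever suits your data) to flag such rows.

```python
import numpy as np

# Hypothetical data: the second row never reaches the threshold.
sizes = np.array([[1, 3, 6, 6],
                  [1, 2, 3, 4]])
threshold = 5

mask = sizes >= threshold
idx = np.argmax(mask, axis=1)  # [2 0]: the all-False row also gives 0
found = mask.any(axis=1)       # [ True False]: which rows actually hit the threshold
rows = np.arange(idx.size)

# Keep the selected value where a match exists, otherwise use the sentinel -1.
values = np.where(found, sizes[rows, idx], -1)
print(values)  # [ 6 -1]
```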