How to apply the output of numpy.argpartition for 2-D Arrays?

I have a largish 2d numpy array, and I want to extract the lowest 10 elements of each row as well as their indexes. Since my array is largish, I would prefer not to sort the whole array.
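For reference, the full-sort approach the question wants to avoid might look like the minimal sketch below. The names `a` and `k` and the random data are hypothetical stand-ins for myBigArray and 10. Every row is sorted in full, which costs O(m log m) per row of length m, whereas a partition needs only roughly linear time.

```python
import numpy as np

# Hypothetical stand-in for myBigArray: 3 rows of 8 random values
rng = np.random.default_rng(0)
a = rng.random((3, 8))
k = 3

# Full sort of every row, then keep only the first k columns
idx_sorted = np.argsort(a, axis=1)[:, :k]  # indexes of the k smallest, ascending
vals_sorted = np.sort(a, axis=1)[:, :k]    # the k smallest values, ascending
```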

I heard about the argpartition() function, with which I can get the indexes of the lowest 10 elements:

top10indexes = np.argpartition(myBigArray,10)[:,:10]

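Note that argpartition() partitions along axis -1 by default, which is what I want. The full result of argpartition() has the same shape as myBigArray and contains indexes into the respective rows, such that the first 10 indexes in each row point to that row's 10 lowest values (in no particular order). Slicing with [:,:10] keeps just those, so top10indexes has shape (number of rows, 10).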
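A small illustration of that behaviour, using a hypothetical 2x5 array and k = 2 in place of 10: the first k columns of each row of the argpartition() result hold the indexes of that row's k smallest values, but no order among them is promised, so the check below compares them after sorting.

```python
import numpy as np

a = np.array([[9, 1, 7, 3, 5],
              [4, 8, 0, 6, 2]])
k = 2

part_idx = np.argpartition(a, k)   # same shape as a, partitioned along axis -1
top_idx = part_idx[:, :k]          # shape (2, k): indexes of the k smallest per row

# Sort the selected values per row, since their order is not guaranteed
smallest = [sorted(row[idxs].tolist()) for row, idxs in zip(a, top_idx)]
# smallest == [[1, 3], [0, 2]]
```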

How can I now extract the elements of myBigArray corresponding to those indexes?

Obvious fancy indexing such as myBigArray[top10indexes] or myBigArray[:,top10indexes] does something quite different. I could also use a list comprehension, something like:

np.array([row[idxs] for row, idxs in zip(myBigArray, top10indexes)])

but that would incur a performance hit from iterating over the rows in Python and converting the result back to an array.
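To see why the "obvious" fancy indexing goes wrong, compare result shapes on a small hypothetical 4x4 array with k = 2. An integer array used as the only index selects whole rows, and `[:, idx]` applies every index row to every array row, so neither pairs each row with its own indexes. The per-row loop does pair them, at the cost of Python-level iteration.

```python
import numpy as np

# Hypothetical data: each row is a permutation of 0..3
a = np.array([[3, 0, 2, 1],
              [1, 3, 0, 2],
              [2, 1, 3, 0],
              [0, 2, 1, 3]])
k = 2
top_idx = np.argpartition(a, k)[:, :k]  # shape (4, 2)

wrong_rows = a[top_idx]     # each index selects a whole row: shape (4, 2, 4)
wrong_cols = a[:, top_idx]  # all index rows applied to every row: shape (4, 4, 2)

# Correct per-row pairing, but via a Python loop: shape (4, 2)
loop_vals = np.array([row[idxs] for row, idxs in zip(a, top_idx)])
```

`a[top_idx]` only runs here because the array is square; on an array with more columns than rows the same expression may raise an IndexError, since the column indexes are then used as row numbers.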

NB: I could just use np.partition() to get the values, and they may even correspond to the indexes (or may not), but I don't want to do the partition twice if I can avoid it.
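On the "may or may not correspond" point: np.partition() and np.argpartition() with the same kth both place the same multiset of k smallest values in the first k columns, so the two agree once each row's k values are sorted; only their order within those columns is unspecified. A sketch with hypothetical random integers (duplicates allowed), using the question's per-row loop to apply the indexes:

```python
import numpy as np

rng = np.random.default_rng(1)
a = rng.integers(0, 100, size=(5, 20))  # hypothetical data; duplicates possible
k = 10

idx = np.argpartition(a, k)[:, :k]
# Values picked through argpartition's indexes (the question's per-row loop)
via_idx = np.array([row[i] for row, i in zip(a, idx)])
# Values from a second, separate partition
via_part = np.partition(a, k)[:, :k]

# Same k smallest values per row; only the order is not guaranteed to match
same = np.array_equal(np.sort(via_idx, axis=1), np.sort(via_part, axis=1))
```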

asked Oct 12 '14 by drevicko

1 Answer

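You can avoid flattened copies, and extracting more values than you need, by indexing with a column of row numbers that broadcasts against the argpartition() result, so that each row is paired with its own indexes: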

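import numpy as np
num = 10
# Column indexes of the num smallest values in each row (in no particular order)
top = np.argpartition(myBigArray, num, axis=1)[:, :num]
# Row numbers of shape (n_rows, 1) broadcast against top, shape (n_rows, num)
myBigArray[np.arange(myBigArray.shape[0])[:, None], top]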

For NumPy >= 1.9.0 this will be very efficient and comparable to np.take().
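A self-contained sketch of this approach, with a hypothetical random array standing in for myBigArray. It also shows `np.take_along_axis` (added in NumPy 1.15, so newer than the 1.9.0 mentioned above) as an equivalent spelling, plus an optional step that orders the num selected entries per row. That step is cheap because it sorts only an (n_rows, num) slice, not the full rows.

```python
import numpy as np

rng = np.random.default_rng(42)
myBigArray = rng.random((1000, 500))  # hypothetical stand-in data
num = 10

# Column indexes of the num smallest values in each row (unordered)
top = np.argpartition(myBigArray, num, axis=1)[:, :num]

# Answer's approach: row numbers (n_rows, 1) broadcast against top (n_rows, num)
rows = np.arange(myBigArray.shape[0])[:, None]
top_vals = myBigArray[rows, top]

# Equivalent helper in NumPy >= 1.15
top_vals_taa = np.take_along_axis(myBigArray, top, axis=1)

# Optional: order the num selected entries ascending, sorting only the small slice
order = np.argsort(top_vals, axis=1)
top_sorted_idx = np.take_along_axis(top, order, axis=1)
top_sorted_vals = np.take_along_axis(top_vals, order, axis=1)
```

Both spellings build the same (n_rows, num) result; take_along_axis simply saves you from constructing the row-number column by hand.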

answered Oct 08 '22 by Saullo G. P. Castro