
Sort each column of a numpy.ndarray using the output of numpy.argsort

I would like to sort 2D numpy arrays according to a previously processed reference array. My idea was to store the numpy.argsort output of my reference array and use it to sort the other ones:

In [13]: # my reference array
    ...: ref_arr = np.random.randint(10, 30, 12).reshape(3, 4)
Out[14]:
array([[12, 22, 12, 13],
       [28, 26, 21, 23],
       [24, 14, 16, 25]])

# desired output:
array([[12, 14, 12, 13],
       [24, 22, 16, 23],
       [28, 26, 21, 25]])

What I tried:

In [15]: # store the sorting matrix
    ...: sm = np.argsort(ref_arr, axis=0)
Out[16]:
array([[0, 2, 0, 0],
       [2, 0, 2, 1],
       [1, 1, 1, 2]])

But unfortunately, the final step only works with one-dimensional arrays:

In [17]: ref_arr[sm]
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-17-48b785178465> in <module>()
----> 1 ref_arr[sm]

IndexError: index 3 is out of bounds for axis 0 with size 3

I found this GitHub issue that was created about this problem but, unfortunately, it was resolved with the remark that what I tried works for 1D arrays only. 🙄

In a comment on this issue, an example is mentioned that is similar to my problem. The snippet does not solve my problem because it sorts the array by row rather than by column, but it hints at the direction I have to take:

a[np.arange(np.shape(a)[0])[:,np.newaxis], np.argsort(a)]

Unfortunately, I don't understand the example well enough to adapt it to my use case. Could someone explain how the advanced indexing works here? That might enable me to solve the problem on my own, but ofc I wouldn't mind a turnkey solution as well. ;)

Thank you.

Just in case: I am using Python 3.6.1 and numpy 1.12.1 on OS X.

asked Apr 04 '17 22:04 by wedi




1 Answer

As of May 2018, this can be done using np.take_along_axis (available from NumPy 1.15):

np.take_along_axis(ref_arr, sm, axis=0)
Out[25]: 
array([[10, 16, 15, 10],
       [13, 23, 24, 12],
       [28, 26, 28, 28]])
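
To make this easier to check against the question, here is a minimal sketch that hard-codes the ref_arr values from the question instead of drawing random ones. It also shows what I believe is the column-wise counterpart of the row-wise snippet from the GitHub comment, for NumPy versions older than 1.15: each row index in sm is paired with the index of its own column, so result[i, j] = ref_arr[sm[i, j], j]. The names cols and other are my own, and other is just a made-up second array to illustrate reusing the stored order.

```python
import numpy as np

# Reference array from the question, hard-coded so the result is reproducible.
ref_arr = np.array([[12, 22, 12, 13],
                    [28, 26, 21, 23],
                    [24, 14, 16, 25]])

# Per-column sort order of the reference array.
sm = np.argsort(ref_arr, axis=0)

# NumPy >= 1.15: apply the stored order along axis 0.
sorted_ref = np.take_along_axis(ref_arr, sm, axis=0)

# Older NumPy: advanced indexing. cols has shape (4,) and broadcasts against
# sm's shape (3, 4), so element [i, j] is taken from ref_arr[sm[i, j], j].
cols = np.arange(ref_arr.shape[1])
sorted_ref_old = ref_arr[sm, cols]

# The same sm reorders any other array of the same shape
# (other is a hypothetical example array).
other = np.arange(12).reshape(3, 4)
sorted_other = other[sm, cols]

print(sorted_ref)
# [[12 14 12 13]
#  [24 22 16 23]
#  [28 26 21 25]]
```

The row-wise snippet from the question works the same way with the roles swapped: np.arange(a.shape[0])[:, np.newaxis] supplies the fixed row index as a column vector, and np.argsort(a) supplies the reordered column indices.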
answered Sep 28 '22 02:09 by vozman