
Sort a numpy array by another array, along a particular axis

Similar to this answer, I have a pair of 3D numpy arrays, a and b, and I want to sort the entries of b by the values of a. Unlike this answer, I want to sort only along one axis of the arrays.

My naive reading of the numpy.argsort() documentation:

    Returns
    -------
    index_array : ndarray, int
        Array of indices that sort `a` along the specified axis.
        In other words, ``a[index_array]`` yields a sorted `a`.

led me to believe that I could do my sort with the following code:

    import numpy

    a = numpy.zeros((3, 3, 3))
    a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
    print("a")
    print(a)
    """
    [[[ 1.  1.  1.]
      [ 1.  1.  1.]
      [ 1.  1.  1.]]

     [[ 3.  3.  3.]
      [ 3.  3.  3.]
      [ 3.  3.  3.]]

     [[ 2.  2.  2.]
      [ 2.  2.  2.]
      [ 2.  2.  2.]]]
    """

    b = numpy.arange(3*3*3).reshape((3, 3, 3))
    print("b")
    print(b)
    """
    [[[ 0  1  2]
      [ 3  4  5]
      [ 6  7  8]]

     [[ 9 10 11]
      [12 13 14]
      [15 16 17]]

     [[18 19 20]
      [21 22 23]
      [24 25 26]]]
    """

    print("a, sorted")
    print(numpy.sort(a, axis=0))
    """
    [[[ 1.  1.  1.]
      [ 1.  1.  1.]
      [ 1.  1.  1.]]

     [[ 2.  2.  2.]
      [ 2.  2.  2.]
      [ 2.  2.  2.]]

     [[ 3.  3.  3.]
      [ 3.  3.  3.]
      [ 3.  3.  3.]]]
    """

    # This isn't working how I'd like
    sort_indices = numpy.argsort(a, axis=0)
    c = b[sort_indices]
    """
    Desired output:

    [[[ 0  1  2]
      [ 3  4  5]
      [ 6  7  8]]

     [[18 19 20]
      [21 22 23]
      [24 25 26]]

     [[ 9 10 11]
      [12 13 14]
      [15 16 17]]]
    """
    print("Desired shape of b[sort_indices]: (3, 3, 3).")
    print("Actual shape of b[sort_indices]:")
    print(c.shape)
    """
    (3, 3, 3, 3, 3)
    """
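The docstring sentence quoted above holds for one-dimensional input. A minimal sketch contrasting that case with the 3-D arrays from the question (x and order are illustrative names, not part of the original code):

```python
import numpy

# 1-D case: indexing with the argsort result really does sort the array
x = numpy.array([1, 3, 2])
order = numpy.argsort(x)
print(x[order])  # [1 2 3]

# 3-D case, same arrays as in the question
a = numpy.zeros((3, 3, 3)) + numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3 * 3 * 3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)

# A lone integer index array applies to axis 0 only, and each of its
# entries selects a whole b[k] slice of shape (3, 3), so the result
# shape is sort_indices.shape + b.shape[1:]
c = b[sort_indices]
print(c.shape)  # (3, 3, 3, 3, 3)
```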

What's the right way to do this?

asked May 27 '11 at 17:05 by Andrew


1 Answer

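You still have to supply indices for the other two dimensions for this to work correctly. Indexing b with a single integer array only indexes the first axis, so every entry of sort_indices selects a whole 3×3 slice of b; that is where the two extra dimensions in the shape (3, 3, 3, 3, 3) come from.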

    >>> import numpy
    >>> a = numpy.zeros((3, 3, 3))
    >>> a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
    >>> b = numpy.arange(3*3*3).reshape((3, 3, 3))
    >>> sort_indices = numpy.argsort(a, axis=0)
    >>> static_indices = numpy.indices((3, 3, 3))
    >>> b[sort_indices, static_indices[1], static_indices[2]]
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],

           [[18, 19, 20],
            [21, 22, 23],
            [24, 25, 26]],

           [[ 9, 10, 11],
            [12, 13, 14],
            [15, 16, 17]]])

numpy.indices((3, 3, 3)) returns an array of shape (3, 3, 3, 3): one 3×3×3 array per axis, in which every entry holds that position's index along that axis (for n axes you get n such arrays). In other words, this (apologies for the long post):

    >>> static_indices
    array([[[[0, 0, 0],
             [0, 0, 0],
             [0, 0, 0]],

            [[1, 1, 1],
             [1, 1, 1],
             [1, 1, 1]],

            [[2, 2, 2],
             [2, 2, 2],
             [2, 2, 2]]],


           [[[0, 0, 0],
             [1, 1, 1],
             [2, 2, 2]],

            [[0, 0, 0],
             [1, 1, 1],
             [2, 2, 2]],

            [[0, 0, 0],
             [1, 1, 1],
             [2, 2, 2]]],


           [[[0, 1, 2],
             [0, 1, 2],
             [0, 1, 2]],

            [[0, 1, 2],
             [0, 1, 2],
             [0, 1, 2]],

            [[0, 1, 2],
             [0, 1, 2],
             [0, 1, 2]]]])

These are the identity indices for each axis; when used to index b, they recreate b.

    >>> b[static_indices[0], static_indices[1], static_indices[2]]
    array([[[ 0,  1,  2],
            [ 3,  4,  5],
            [ 6,  7,  8]],

           [[ 9, 10, 11],
            [12, 13, 14],
            [15, 16, 17]],

           [[18, 19, 20],
            [21, 22, 23],
            [24, 25, 26]]])
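Replacing just one of those identity arrays with the argsort result is exactly what b[sort_indices, static_indices[1], static_indices[2]] does for axis 0. A minimal sketch generalizing that to any axis; the helper name sort_along_axis is hypothetical, and it assumes keys and values have the same shape:

```python
import numpy


def sort_along_axis(keys, values, axis=0):
    """Reorder values by the sort order of keys along one axis.

    Hypothetical helper; keys and values must have the same shape.
    """
    order = numpy.argsort(keys, axis=axis)
    # One identity index array per axis (numpy.indices stacks them)
    index = list(numpy.indices(keys.shape))
    # Swap in the sort order only for the axis being sorted
    index[axis] = order
    return values[tuple(index)]


a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3 * 3 * 3).reshape((3, 3, 3))
c = sort_along_axis(a, b, axis=0)
print(c[:, 0, 0])  # [ 0 18  9]
```

Building the full identity grids costs memory proportional to ndim times the array size; the open-grid approach avoids that.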

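As an alternative to numpy.indices, you could use numpy.ogrid, as unutbu suggests. ogrid returns open grids (here arrays of shape (3, 1, 1), (1, 3, 1) and (1, 1, 3) that rely on broadcasting) instead of full 3×3×3 arrays, so the objects it generates are much smaller. I'll create all three axes just for consistency's sake, but note unutbu's comment for a way to do this by generating only two.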

    >>> static_indices = numpy.ogrid[0:a.shape[0], 0:a.shape[1], 0:a.shape[2]]
    >>> a[sort_indices, static_indices[1], static_indices[2]]
    array([[[ 1.,  1.,  1.],
            [ 1.,  1.,  1.],
            [ 1.,  1.,  1.]],

           [[ 2.,  2.,  2.],
            [ 2.,  2.,  2.],
            [ 2.,  2.,  2.]],

           [[ 3.,  3.,  3.],
            [ 3.,  3.,  3.],
            [ 3.,  3.,  3.]]])
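A minimal sketch of two variations, assuming NumPy 1.15 or later for the second. The first builds open grids only for the two axes that are not being sorted (one way to do it; it may not be the exact form of unutbu's comment): grids of shape (3, 1) and (1, 3) broadcast against the trailing axes of sort_indices. The second swaps in numpy.take_along_axis, which was added to NumPy long after this answer was written and does the bookkeeping for the other axes itself.

```python
import numpy

a = numpy.zeros((3, 3, 3))
a += numpy.array((1, 3, 2)).reshape((3, 1, 1))
b = numpy.arange(3 * 3 * 3).reshape((3, 3, 3))
sort_indices = numpy.argsort(a, axis=0)

# Open grids for axes 1 and 2 only: shapes (3, 1) and (1, 3), which
# broadcast against sort_indices' shape (3, 3, 3)
i, j = numpy.ogrid[0:a.shape[1], 0:a.shape[2]]
c1 = b[sort_indices, i, j]

# NumPy >= 1.15: take_along_axis pairs each index with its position
# along the remaining axes
c2 = numpy.take_along_axis(b, sort_indices, axis=0)

print(numpy.array_equal(c1, c2))  # True
print(c2[:, 0, 0])  # [ 0 18  9]
```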
answered Oct 03 '22 at 10:10 by senderle