
Sort NumPy array with custom predicate

I'd like to sort a NumPy array of shape [n, 4] along its first dimension (size n), using a custom predicate that operates on the rows (vectors of size 4). The C++ version of what I'd like to do is below; it's quite simple really. I've seen how to do this with Python lists, but I can't find the syntax to do it with NumPy arrays. Is this possible? The documentation for np.sort, np.argsort, and np.lexsort doesn't mention custom predicates.

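// C++ version
#include <algorithm>
#include <array>
#include <vector>

// std::array is used because raw arrays (float[4]) cannot be stored in a std::vector
std::vector<std::array<float, 4>> v = init_v();
std::array<float, 4> p = init_p();
std::sort(v.begin(), v.end(), [&p](const auto& lhs, const auto& rhs) {
   return myfn(p, lhs) > myfn(p, rhs); });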

EDIT: below is the Python code I would like to use for the sorting. That is, for each row (shape [4]) of my array, I'd calculate the squared Euclidean 3D distance (i.e. using only the first 3 columns) to a fixed point.

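import numpy as np

# these both operate on numpy vectors of shape [4] (i.e. a single row of my data matrix)
def dist_sq(a, b):
    d = a[:3] - b[:3]    # only the first 3 components contribute
    return np.dot(d, d)  # np.dot takes two arguments; d . d is the squared length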

def sort_pred(lhs, rhs, p):
    return dist_sq(lhs, p) > dist_sq(rhs, p)
memo asked Aug 21 '17 14:08


1 Answer

In NumPy you would apply the (vectorized) order-defining function to the whole array, then use np.argsort to sort by the result.

This is less space-efficient than the C++ version, because it allocates temporary arrays for the computed keys and the sort indices, but that is how you usually achieve performance with NumPy.

import numpy as np

def myfn(x):
    return np.sin(x[:, 1])  # example: sort by the sine of the second column

a = np.random.randn(10, 4)

predicate = myfn(a)  # not sure if predicate is the best name for this variable
order = np.argsort(predicate)

a_sorted = a[order]
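
Applied to the distance function from the question, the same idea might look like the sketch below. The helper name dist_sq_rows and the sample values of a and p are made up for illustration. Because the C++ comparator uses `>`, the rows are ordered by decreasing distance, which is done here by negating the keys before the argsort.

```python
import numpy as np

def dist_sq_rows(a, p):
    """Squared 3D distance from every row of a (shape [n, 4]) to p (shape [4])."""
    d = a[:, :3] - p[:3]         # broadcasting subtracts p from each row; 4th column ignored
    return (d * d).sum(axis=1)   # one key per row, shape [n]

# Illustrative data only: 10 random rows and an arbitrary reference point.
rng = np.random.default_rng(0)
a = rng.standard_normal((10, 4))
p = np.array([0.5, -1.0, 2.0, 0.0])

keys = dist_sq_rows(a, p)
order = np.argsort(-keys, kind="stable")  # negated keys give largest distance first
a_sorted = a[order]
```

Dropping the minus sign gives ascending order (nearest rows first).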
MB-F answered Oct 19 '22 06:10
