Fastest way to filter a numpy array by a set of values

Tags: python, numpy

I am pretty new to numpy, and I am also using PyPy 2.2, which has limited numpy support (see http://buildbot.pypy.org/numpy-status/latest.html). What I'm trying to do is filter an array by a set of values, i.e. keep a row if its third column (index 2) holds a value from the set. I can do it with a list comprehension, but I'd rather avoid the intermediate list: on longer arrays it isn't fast, and I can't help thinking that numpy filtering would be faster.

>>> import numpy as np
>>> a = np.array([[   368,    322, 175238,      2],
...               [   430,    382, 121486,      2],
...               [   451,    412, 153521,      2],
...               [   480,    442, 121468,      2],
...               [   517,    475, 109543,      2],
...               [   543,    503, 121471,      2],
...               [   576,    537, 100566,      2],
...               [   607,    567, 121473,      2],
...               [   640,    597, 153561,      2]])

>>> b = {121486, 153521, 121473}

>>> np.array([x for x in a if x[2] in b])
array([[   430,    382, 121486,      2],
       [   451,    412, 153521,      2],
       [   607,    567, 121473,      2]])
asked Dec 02 '13 16:12 by Handloomweaver

1 Answer

You can do it in one line, but you have to convert the set with list(b) first (numpy does not treat a Python set as a sequence of values), so it might not actually be any faster:

>>> a[np.in1d(a[:,2], list(b))]
array([[   430,    382, 121486,      2],
       [   451,    412, 153521,      2],
       [   607,    567, 121473,      2]])
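A self-contained sketch of the same one-liner, assuming NumPy 1.13 or later, where np.isin is available (np.in1d is deprecated as of NumPy 2.0, and np.isin is its recommended replacement). It also swaps list(b) for np.fromiter, which fills an array directly from the set without a temporary list; whether that makes a measurable difference is not tested here.

```python
import numpy as np

a = np.array([[368, 322, 175238, 2],
              [430, 382, 121486, 2],
              [451, 412, 153521, 2],
              [480, 442, 121468, 2],
              [517, 475, 109543, 2],
              [543, 503, 121471, 2],
              [576, 537, 100566, 2],
              [607, 567, 121473, 2],
              [640, 597, 153561, 2]])
b = {121486, 153521, 121473}

# Convert the set to an array once; np.fromiter fills it without a temporary list.
vals = np.fromiter(b, dtype=a.dtype, count=len(b))

# np.isin returns a boolean mask shaped like its first argument (column 2 here).
mask = np.isin(a[:, 2], vals)

# Boolean indexing keeps only the rows where the mask is True, in original order.
filtered = a[mask]
print(filtered)
```

The set is unordered, but that does not matter: the mask is computed per row of a, so the filtered rows keep the order they had in a.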

It works because np.in1d tells you which elements of the first array are in the second:

>>> np.in1d(a[:,2], list(b))
array([False,  True,  True, False, False, False, False,  True, False], dtype=bool)
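To make the meaning of that mask concrete, here is a brute-force way to compute the same boolean array with broadcasting. This is a different technique, not what np.in1d does internally: it compares every value in column 2 against every value in the set, so memory grows with len(a) * len(b), and it is only reasonable when b is small.

```python
import numpy as np

a = np.array([[368, 322, 175238, 2],
              [430, 382, 121486, 2],
              [451, 412, 153521, 2],
              [480, 442, 121468, 2],
              [517, 475, 109543, 2],
              [543, 503, 121471, 2],
              [576, 537, 100566, 2],
              [607, 567, 121473, 2],
              [640, 597, 153561, 2]])
vals = np.array([121486, 153521, 121473])

# a[:, 2, None] has shape (9, 1) and vals has shape (3,), so == broadcasts
# to a (9, 3) table of comparisons; any(axis=1) is True for a row if its
# column-2 value matched at least one entry of vals.
mask = (a[:, 2, None] == vals).any(axis=1)
print(mask)
print(a[mask])
```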

For large a and small b, I think np.in1d might be faster. For large a and large b, the following is probably faster than your solution: it still uses b as a set, but builds only a boolean array instead of rebuilding the entire array one row at a time:

ainb = np.array([x in b for x in a[:,2]])
a[ainb]
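The comprehension above still creates a temporary Python list of booleans before np.array copies it. A minimal variant, assuming the same a and b as in the question, fills the boolean mask directly from a generator with np.fromiter; it keeps the average O(1) set lookup per row but skips the intermediate list. No speed claim is made for it here.

```python
import numpy as np

a = np.array([[368, 322, 175238, 2],
              [430, 382, 121486, 2],
              [451, 412, 153521, 2],
              [480, 442, 121468, 2],
              [517, 475, 109543, 2],
              [543, 503, 121471, 2],
              [576, 537, 100566, 2],
              [607, 567, 121473, 2],
              [640, 597, 153561, 2]])
b = {121486, 153521, 121473}

col = a[:, 2]
# np.fromiter writes each generated bool straight into a preallocated
# array of length count, so no Python list of booleans is built.
ainb = np.fromiter((x in b for x in col), dtype=bool, count=len(col))
print(a[ainb])
```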

For small a and large b, your own solution is probably fastest.
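Since the size thresholds above are stated as guesses, a small timeit harness can check them against your own data. The function names, array sizes, and random data below are hypothetical placeholders; the sketch assumes NumPy 1.17+ (for default_rng) and verifies that all three approaches return the same rows before timing them.

```python
import timeit

import numpy as np


def filter_listcomp(a, b):
    # The question's approach: rebuild the array from a list of matching rows.
    return np.array([x for x in a if x[2] in b])


def filter_isin(a, b):
    # The answer's one-liner, using the newer np.isin spelling.
    return a[np.isin(a[:, 2], list(b))]


def filter_setmask(a, b):
    # The answer's second approach: set lookups build only a boolean mask.
    return a[np.array([x in b for x in a[:, 2]], dtype=bool)]


rng = np.random.default_rng(0)
# Hypothetical sizes; adjust them to resemble your real a and b.
a = rng.integers(0, 200_000, size=(50_000, 4))
b = set(rng.integers(0, 200_000, size=1_000).tolist())

funcs = (filter_listcomp, filter_isin, filter_setmask)
results = [f(a, b) for f in funcs]
# All approaches must agree before their timings mean anything.
assert all(np.array_equal(results[0], r) for r in results[1:])

for f in funcs:
    t = timeit.timeit(lambda: f(a, b), number=3)
    print(f"{f.__name__}: {t / 3:.4f} s per call")
```

Varying the sizes of a and b shows where each approach wins on a given machine. On PyPy, numpy support may not cover every function used here, so some variants might need to be dropped.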

answered Oct 08 '22 05:10 by askewchan