
Filter rows of a numpy array?

I am looking to apply a function to each row of a numpy array. If this function evaluates to true I will keep the row, otherwise I will discard it. For example, my function might be:

def f(row):
    if sum(row) > 10: return True
    else: return False

I was wondering if there was something similar to:

np.apply_over_axes() 

which applies a function to each row of a numpy array and returns the result. I was hoping for something like:

np.filter_over_axes() 

which would apply a function to each row of a numpy array and return only the rows for which the function returned True. Is there anything like this? Or should I just use a for loop?

killajoule asked Oct 02 '14 05:10




1 Answer

Ideally, you would be able to implement a vectorized version of your function and use that to do boolean indexing. For the vast majority of problems this is the right solution. Numpy provides quite a few functions that can act over various axes as well as all the basic operations and comparisons, so most useful conditions should be vectorizable.

import numpy as np

x = np.random.randn(20, 3)
x_new = x[np.sum(x, axis=1) > .5]
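
To make the boolean indexing step easier to follow, here is a small sketch on a hand-checkable array, using the threshold of 10 from the question's f(row). The array values and the names row_sums, mask and filtered are made up for illustration.

```python
import numpy as np

# Small illustrative array (values are assumptions, chosen to be easy to check).
x = np.array([
    [1, 2, 3],   # row sum 6
    [4, 5, 6],   # row sum 15
    [7, 0, 1],   # row sum 8
    [9, 1, 1],   # row sum 11
])

row_sums = x.sum(axis=1)   # one sum per row
mask = row_sums > 10       # boolean array with one entry per row
filtered = x[mask]         # keeps only the rows where mask is True

print(mask)
print(filtered)
```

Only the second and fourth rows have a sum above 10, so those are the rows that survive. The mask has one entry per row, which is what lets it select along the first axis.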

If you are absolutely sure that you can't do the above, I would suggest using a list comprehension (or np.apply_along_axis) to create an array of bools to index with.

def myfunc(row):
    return sum(row) > .5

bool_arr = np.array([myfunc(row) for row in x])
x_new = x[bool_arr]
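
For completeness, the np.apply_along_axis route mentioned above might look like the sketch below. It assumes the predicate takes a 1-D row and returns a single bool; the seeded generator is only there to make the example reproducible and is not part of the original answer.

```python
import numpy as np

def myfunc(row):
    # Row predicate: must return a single bool for each 1-D row.
    return row.sum() > 0.5

# Seeded random data so the sketch is reproducible (an assumption for illustration).
rng = np.random.default_rng(0)
x = rng.standard_normal((20, 3))

# Call myfunc on every 1-D slice along axis 1, i.e. on every row.
bool_arr = np.apply_along_axis(myfunc, 1, x)

# bool_arr has one entry per row, so it can be used as a row mask.
x_new = x[bool_arr]
```

As far as I know, np.apply_along_axis still calls the Python function once per row, so it should be seen as a convenience rather than a speedup over the list comprehension.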

This will get the job done in a relatively clean way, but will be significantly slower than a vectorized version. An example:

x = np.random.randn(5000, 200)

%timeit x[np.sum(x, axis=1) > .5]
# 100 loops, best of 3: 5.71 ms per loop

%timeit x[np.array([myfunc(row) for row in x])]
# 1 loops, best of 3: 217 ms per loop
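
%timeit is an IPython magic, so it will not run in a plain script. A rough equivalent using the standard timeit module could look like this. The helper names vectorized and looped, the seed, and the repeat count n are assumptions for illustration, and the numbers you get will depend on your machine.

```python
import timeit

import numpy as np

def myfunc(row):
    # Same row predicate as in the answer.
    return sum(row) > 0.5

# Seeded data with the same shape as the timing example above.
rng = np.random.default_rng(0)
x = rng.standard_normal((5000, 200))

def vectorized():
    # Row sums computed inside NumPy, then boolean indexing.
    return x[np.sum(x, axis=1) > 0.5]

def looped():
    # Python-level loop over rows to build the mask.
    return x[np.array([myfunc(row) for row in x])]

n = 3  # small repeat count to keep the run short
t_vec = timeit.timeit(vectorized, number=n) / n
t_loop = timeit.timeit(looped, number=n) / n

print(f"vectorized: {t_vec * 1e3:.2f} ms per call")
print(f"looped:     {t_loop * 1e3:.2f} ms per call")
```

Both functions should select the same rows; only the time spent building the mask differs.
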
Roger Fan answered Oct 05 '22 13:10