
Efficient thresholding filter of an array with numpy

I need to filter an array to remove the elements that are lower than a certain threshold. My current code is like this:

    import numpy

    threshold = 5
    a = numpy.array(range(10))  # testing data
    # list() is needed on Python 3, where filter() returns a lazy iterator
    b = numpy.array(list(filter(lambda x: x >= threshold, a)))

The problem is that this builds a temporary list by calling a lambda function on every element through filter, which is slow.

As this is a quite simple operation, maybe there is a numpy function that does it in an efficient way, but I've been unable to find it.

I've thought that another way to achieve this could be sorting the array, finding the index of the threshold and returning a slice from that index onwards. But even if this were faster for small inputs (where it won't be noticeable anyway), it's definitely asymptotically less efficient as the input size grows: sorting is O(n log n), while a single filtering pass is O(n).
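The sort-and-slice idea described above could look roughly like the following. This is only a sketch: it assumes `numpy.sort` and `numpy.searchsorted` are used, and the helper name `sort_and_slice` is made up for illustration. Note that the result comes back sorted, so the original element order is lost.

```python
import numpy as np

def sort_and_slice(a, threshold):
    """Keep the elements >= threshold by sorting, then slicing (hypothetical helper)."""
    s = np.sort(a)  # O(n log n); returns a sorted copy, `a` is untouched
    # Index of the first element >= threshold, found by binary search
    i = np.searchsorted(s, threshold, side="left")
    return s[i:]

a = np.array([7, 2, 9, 5, 0, 5, 3])
result = sort_and_slice(a, 5)
print(result)
```

With `side="left"`, elements equal to the threshold are kept, which matches the `x >= threshold` condition in the question.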

Any ideas? Thanks!

Update: I took some measurements too, and sorting + slicing was still twice as fast as the pure Python filter when the input had 100,000,000 entries.

    In [321]: r = numpy.random.uniform(0, 1, 100000000)

    In [322]: %timeit test1(r) # filter
    1 loops, best of 3: 21.3 s per loop

    In [323]: %timeit test2(r) # sort and slice
    1 loops, best of 3: 11.1 s per loop

    In [324]: %timeit test3(r) # boolean indexing
    1 loops, best of 3: 1.26 s per loop
fortran asked Nov 03 '11 11:11


1 Answer

This should do it (using >= so that elements equal to the threshold are kept, as in the question):

    b = a[a >= threshold]
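To make the one-liner above more concrete, here is a minimal sketch that splits it into its two steps, using the testing data from the question. The variable name `mask` is introduced only for illustration, and the comparison is written as `>=` to match the question's condition of keeping elements equal to the threshold.

```python
import numpy as np

threshold = 5
a = np.arange(10)  # same testing data as in the question

mask = a >= threshold  # boolean array, True where the element is kept
b = a[mask]            # new array holding only the selected elements

print(mask)
print(b)
```

Boolean indexing returns a copy, so modifying `b` afterwards leaves `a` unchanged.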

I tested as follows:

    import datetime
    import numpy as np

    # one million zeros followed by one million ones
    lrg = np.arange(2).reshape((2, -1)).repeat(1000000, -1).flatten()

    # boolean indexing
    t0 = datetime.datetime.now()
    flt = lrg[lrg == 0]
    print(datetime.datetime.now() - t0)

    # pure Python filter with a lambda
    t0 = datetime.datetime.now()
    flt = np.array(list(filter(lambda x: x == 0, lrg)))
    print(datetime.datetime.now() - t0)

I got

    $ python test.py
    0:00:00.028000
    0:00:02.461000
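For a more repeatable comparison than a single wall-clock measurement, the same test could be run through the standard `timeit` module. The following is a sketch under assumptions: the function names `with_mask` and `with_filter` are made up here, and the array is deliberately smaller than in the test above so the slow variant finishes quickly. Absolute timings will differ from machine to machine.

```python
import timeit
import numpy as np

# 100,000 zeros followed by 100,000 ones (smaller than the test above)
lrg = np.repeat(np.arange(2), 100000)

def with_mask(arr):
    # boolean indexing, evaluated inside numpy
    return arr[arr == 0]

def with_filter(arr):
    # pure Python filter: the lambda is called once per element
    return np.array(list(filter(lambda x: x == 0, arr)))

# Best of 3 runs, one call per run
t_mask = min(timeit.repeat(lambda: with_mask(lrg), number=1, repeat=3))
t_filter = min(timeit.repeat(lambda: with_filter(lrg), number=1, repeat=3))

print("boolean indexing: %.6f s" % t_mask)
print("filter + lambda:  %.6f s" % t_filter)
```

Taking the minimum over several runs reduces the influence of other activity on the machine.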

http://docs.scipy.org/doc/numpy/user/basics.indexing.html#boolean-or-mask-index-arrays

yosukesabai answered Nov 05 '22 05:11