Faster way to split a numpy array according to a threshold

Suppose I have a random numpy array:

X = np.arange(1000)

and a threshold:

thresh = 50

I want to split X into two partitions, X_l and X_r, such that every element of X_l is less than or equal to thresh, while every element of X_r is greater than thresh. The two partitions are then passed to a recursive function.

Using NumPy, I create a boolean mask and use it to partition X:

Z = X <= thresh
X_l, X_r = X[Z == 1], X[Z == 0]
recursive_call(X_l, X_r)

This is done many times; is there a way to make it faster? Is it possible to avoid copying the partitions at each call?

asked Feb 16 '23 by blueSurfer

1 Answer

X[~Z] is faster than X[Z == 0], since ~Z simply inverts the existing boolean mask, whereas Z == 0 allocates a second boolean array by comparing every element of Z against zero:

In [13]: import numpy as np

In [14]: X = np.random.randint(0, 1001, size=1000)

In [15]: thresh = 50

In [18]: Z = X <= thresh

In [19]: %timeit X_l, X_r = X[Z == 1], X[Z == 0]
10000 loops, best of 3: 23.9 µs per loop

In [20]: %timeit X_l, X_r = X[Z], X[~Z]
100000 loops, best of 3: 16.4 µs per loop

Have you profiled your code to determine that this splitting is really the bottleneck? If it accounts for only 1% of the running time, then no matter how much you optimize it, the overall performance can improve by no more than 1%.

You might benefit more from rethinking your algorithm or data structures than from optimizing this one operation. And if this really is the bottleneck, you might also do better by rewriting this piece of code in C or Cython.
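For instance, if your algorithm allows you to sort the array once up front, each subsequent split can be done with np.searchsorted plus basic slicing, which returns views of the original buffer rather than copies. A minimal sketch (split_sorted is a made-up name, not from the question's code):

import numpy as np

def split_sorted(arr, thresh):
    # arr must be sorted; find the first index whose value exceeds thresh
    i = np.searchsorted(arr, thresh, side='right')
    # basic slices are views, so no data is copied
    return arr[:i], arr[i:]

X = np.sort(np.random.randint(0, 1001, size=1000))
X_l, X_r = split_sorted(X, 50)

The trade-off is a one-time O(n log n) sort, after which every split is O(log n) and allocation-free.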

When your numpy arrays only have around 1000 elements, there is a chance that plain Python lists/sets/dicts are quicker. The speed benefit of NumPy arrays often does not become apparent until the arrays are quite large. You might want to rewrite your code in pure Python and benchmark the two versions with timeit.

Hm, let me rephrase that. It's not really the size of the array that makes NumPy quicker or slower. It's just that having small NumPy arrays is sometimes a sign that you are creating lots of small NumPy arrays, and the creation of a NumPy array is significantly slower than the creation of, say, a Python list:

In [21]: %timeit np.array([])
100000 loops, best of 3: 4.31 us per loop

In [22]: %timeit []
10000000 loops, best of 3: 29.5 ns per loop

In [23]: 4310/29.5
Out[23]: 146.10169491525423

Also, when you code in pure Python, you might be more likely to use dicts and sets, which have no direct NumPy equivalent. That might lead you to an alternative algorithm that is quicker.
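For comparison, a pure-Python version of the same split might look like this (split_py is a hypothetical name, just for benchmarking against the NumPy version):

def split_py(xs, thresh):
    # two passes over a plain Python list; no NumPy arrays are created
    left = [x for x in xs if x <= thresh]
    right = [x for x in xs if x > thresh]
    return left, right

X_l, X_r = split_py(list(range(1000)), 50)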

answered May 02 '23 by unutbu