applying numpy.searchsorted over multiple axes

I'm trying to pass a 2d array of sorted values and a 1d array of values to searchsorted and have it return a corresponding 1d array of index values. I have to perform this operation many times, and I'm trying to find an efficient way to do so. Ultimately, I'd like to pass a 3d array of values sorted along an axis, a 2d array of values to search for and have the function return a 2d array of indices.

Here is a minimal working example:

from numpy import *
random.seed(1234)

a = zeros((4,10))
a[:,1:-1] = random.rand(4,8)
a[:,1:-1].sort(1)
a[:,-1] = 1.

v = random.rand(4)

res = array([searchsorted(a[j], v[j]) for j in xrange(4)])

where res should be [9, 1, 7, 6]

Is there an efficient way to do this? Ideally, I'd like to avoid Cython if at all possible. If it helps, each array being searched over should have 9-10 elements, while the number of values being searched for is larger (100-1000 elements) so the 2d array being passed would have a size of 1000x10.

asked May 27 '14 22:05 by user1554752


1 Answer

You can replace

res = array([searchsorted(a[j], v[j]) for j in xrange(4)])

with:

res = (a.T < v).sum(axis=0)

That's it. This works because each row of a is sorted: the number of elements strictly less than v[j] is exactly the left insertion index that searchsorted returns by default. Although this line does not binary search, it vectorizes the iteration over the larger dimension, which gives a 40x speedup (on my Mac) for a 2d array of size (1000, 10).
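If you want to convince yourself that counting matches searchsorted, here is a minimal sketch that compares the two on the same kind of data. The helper name rowwise_searchsorted is made up for illustration (it is not a NumPy function), and the script uses range rather than xrange so that it also runs on Python 3. It assumes every row is sorted and contains no NaNs.

```python
import numpy as np


def rowwise_searchsorted(a, v, side="left"):
    """Per-row insertion indices for a 2d array `a` sorted along axis 1.

    Hypothetical helper for illustration, not part of NumPy.
    Assumes each row of `a` is sorted and free of NaNs.
    """
    a = np.asarray(a)
    v = np.asarray(v)
    if side == "left":
        # Elements strictly below v[j] all sit before the left insertion point.
        return (a < v[:, None]).sum(axis=1)
    # side="right": elements equal to v[j] also sit before the insertion point.
    return (a <= v[:, None]).sum(axis=1)


# Same construction as in the question, scaled up to 1000 rows.
rng = np.random.RandomState(1234)
a = np.zeros((1000, 10))
a[:, 1:-1] = np.sort(rng.rand(1000, 8), axis=1)
a[:, -1] = 1.0
v = rng.rand(1000)

# Reference result: one searchsorted call per row.
expected = np.array([np.searchsorted(a[j], v[j]) for j in range(len(v))])
assert np.array_equal(rowwise_searchsorted(a, v), expected)
```

Note that counting is linear in the row length, so it only pays off when the rows are short, as they are here (about 10 elements).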

To measure the difference, you can time your original command with timeit, averaging over 100 executions per run (1.67 ms per loop):

python2 -m timeit -n 100 \
    -s 'import numpy as np; np.random.seed(1234); a = np.zeros((1000,10)); a[:,1:-1] = np.random.rand(1000,8); a[:,1:-1].sort(1); a[:,-1] = 1.; v = np.random.rand(1000)' \
    'res = np.array([np.searchsorted(a[j], v[j]) for j in xrange(1000)])'

100 loops, best of 3: 1.67 msec per loop

while the proposed version takes only 41.7 microseconds:

python2 -m timeit -n 100 \
    -s 'import numpy as np; np.random.seed(1234); a = np.zeros((1000,10)); a[:,1:-1] = np.random.rand(1000,8); a[:,1:-1].sort(1); a[:,-1] = 1.; v = np.random.rand(1000)' \
    'res = (a.T < v).sum(axis=0)'

100 loops, best of 3: 41.7 usec per loop
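The same counting idea should extend to the 3d case mentioned in the question. The following is a sketch under assumptions, not something I have timed: it assumes the 3d array is sorted along its last axis and that the 2d array of search values matches its leading two dimensions. The shapes and the names a3, v2 and idx are made up for illustration.

```python
import numpy as np

rng = np.random.RandomState(0)

# Hypothetical shapes: 5 blocks of 1000 rows, each row sorted along the last axis.
a3 = np.sort(rng.rand(5, 1000, 10), axis=-1)
v2 = rng.rand(5, 1000)

# Add a trailing axis to v2 so it broadcasts against the sorted axis,
# then count the elements strictly below each search value.
idx = (a3 < v2[..., None]).sum(axis=-1)  # shape (5, 1000)

# Spot-check a few entries against searchsorted on the matching row.
for i, j in [(0, 0), (2, 500), (4, 999)]:
    assert idx[i, j] == np.searchsorted(a3[i, j], v2[i, j])
```

The comparison builds a temporary boolean array with the same shape as a3, so memory use grows with the full 3d size.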
answered Oct 06 '22 03:10 by mohammads