Fast Python NumPy where functionality?

I am calling NumPy's where function many times inside several for loops, and it is far too slow. Is there a faster way to get the same result? I have read that in-line for loops (list comprehensions) and binding functions to local names before the loop can help, but neither improves the speed by much (< 1%). len(UNIQ_IDS) is about 800, and emiss_data and obj_data are NumPy ndarrays with shape = (2600,5200). I profiled the code with the profile module to find the bottlenecks, and the where calls inside the for loops are a big one.

import numpy as np
max = np.max
where = np.where
MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
weather guy asked Aug 26 '13 20:08

2 Answers

Can't you just do

emiss_data[obj_data == i]

? I'm not sure why you're using where at all.
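To illustrate the point, here is a minimal sketch showing that indexing with the boolean mask selects the same elements as going through np.where. The small emiss and obj arrays are made-up stand-ins for emiss_data and obj_data, not the asker's data.

```python
import numpy as np

# Hypothetical stand-ins for emiss_data and obj_data.
emiss = np.array([[0.1, 0.7, 0.3],
                  [0.9, 0.2, 0.5]])
obj = np.array([[1, 2, 1],
                [2, 2, 1]])

mask = obj == 2

# np.where(mask) first converts the mask into a tuple of integer index arrays.
via_where = emiss[np.where(mask)]

# Boolean indexing uses the mask directly and skips that conversion.
via_mask = emiss[mask]

# Both select the same elements in the same (row-major) order.
same = np.array_equal(via_where, via_mask)
```

Dropping where removes one step per ID, but each obj == i comparison still scans the whole array.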

user2357112 supports Monica answered Oct 21 '22 09:10

It turns out that a pure Python loop can be much much faster than NumPy indexing (or calls to np.where) in this case.

Consider the following alternatives:

import numpy as np
import collections
import itertools as IT

shape = (2600,5200)
# shape = (26,52)
emiss_data = np.random.random(shape)
obj_data = np.random.randint(1, 801, size=shape)  # IDs 1..800; random_integers is deprecated
UNIQ_IDS = np.unique(obj_data)

def using_where():
    max = np.max
    where = np.where
    MAX_EMISS = [max(emiss_data[where(obj_data == i)]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_index():
    max = np.max
    MAX_EMISS = [max(emiss_data[obj_data == i]) for i in UNIQ_IDS]
    return MAX_EMISS

def using_max():
    MAX_EMISS = [(emiss_data[obj_data == i]).max() for i in UNIQ_IDS]
    return MAX_EMISS

def using_loop():
    result = collections.defaultdict(list)
    # zip is lazy on Python 3; on Python 2 use IT.izip instead
    for val, idx in zip(emiss_data.ravel(), obj_data.ravel()):
        result[idx].append(val)
    return [max(result[idx]) for idx in UNIQ_IDS]

def using_sort():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    start = 0
    end = 0
    out = np.empty(count.shape[0])
    for ind, x in np.ndenumerate(count):
        end += x
        out[ind] = np.max(np.take(emiss_data, vals[start:end]))
        start += x
    return out

def using_split():
    uind = np.digitize(obj_data.ravel(), UNIQ_IDS) - 1
    vals = uind.argsort()
    count = np.bincount(uind)
    return [np.take(emiss_data, item).max()
            for item in np.split(vals, count.cumsum())[:-1]]

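# using_sort returns an ndarray, so compare with np.allclose rather than ==
expected = using_where()
for func in (using_index, using_max, using_loop, using_sort, using_split):
    assert np.allclose(expected, func())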

Here are the benchmarks, with shape = (2600,5200):

In [57]: %timeit using_loop()
1 loops, best of 3: 9.15 s per loop

In [90]: %timeit using_sort()
1 loops, best of 3: 9.33 s per loop

In [91]: %timeit using_split()
1 loops, best of 3: 9.33 s per loop

In [61]: %timeit using_index()
1 loops, best of 3: 63.2 s per loop

In [62]: %timeit using_max()
1 loops, best of 3: 64.4 s per loop

In [58]: %timeit using_where()
1 loops, best of 3: 112 s per loop

Thus using_loop (pure Python) turns out to be about 12x faster than using_where (9.15 s vs. 112 s), and using_sort and using_split are nearly as fast.
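The timings above come from IPython's %timeit. Outside IPython, a rough equivalent with the standard timeit module might look like the sketch below. The small shape, the ID range of 1 to 8, and the helper name best_time are my assumptions, chosen so the script finishes quickly; only two of the functions are repeated here to keep it short.

```python
import timeit
import numpy as np

shape = (26, 52)  # assumption: small shape so the script runs quickly
rng = np.random.default_rng(0)
emiss_data = rng.random(shape)
obj_data = rng.integers(1, 9, size=shape)  # assumption: IDs 1..8
UNIQ_IDS = np.unique(obj_data)

def using_where():
    return [np.max(emiss_data[np.where(obj_data == i)]) for i in UNIQ_IDS]

def using_index():
    return [emiss_data[obj_data == i].max() for i in UNIQ_IDS]

def best_time(func, repeat=3, number=10):
    """Best per-call time in seconds over `repeat` runs of `number` calls each."""
    return min(timeit.repeat(func, repeat=repeat, number=number)) / number

# Map each function name to its best per-call time.
timings = {f.__name__: best_time(f) for f in (using_where, using_index)}

for name, seconds in sorted(timings.items(), key=lambda kv: kv[1]):
    print("%-12s %.6f s per call" % (name, seconds))
```

Numbers from a small shape like this will not match the full-size results; to reproduce those, set shape back to (2600,5200) and expect each call to take seconds to minutes.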

I'm not entirely sure why pure Python is faster than NumPy here. My guess is that the pure Python version zips (yes, pun intended) through both arrays exactly once, whereas each obj_data == i comparison scans the whole array again, once for every ID in UNIQ_IDS. Despite all the fancy indexing, we really just want to visit each value once, and the loop does that while grouping values by ID as it goes, so it never has to search for the members of each group. But this is just speculation. I didn't know it would be faster until I benchmarked.
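If the cost really does come from rescanning the arrays once per ID, a NumPy version that groups everything with a single sort might be worth trying. The sketch below is not one of the benchmarked functions: it follows the same idea as using_sort but replaces the Python loop with np.maximum.reduceat. grouped_max is a hypothetical helper name, it assumes non-empty input, and I have not timed it against the functions above.

```python
import numpy as np

def grouped_max(values, labels):
    """Return (unique_labels, per-label max of values). Assumes non-empty input."""
    flat_vals = values.ravel()
    flat_labs = labels.ravel()

    # Sort once so that all values sharing a label become contiguous.
    order = np.argsort(flat_labs, kind="stable")
    sorted_labs = flat_labs[order]
    sorted_vals = flat_vals[order]

    # Index where each run of equal labels starts in the sorted array.
    starts = np.flatnonzero(np.r_[True, sorted_labs[1:] != sorted_labs[:-1]])

    # reduceat takes the max over each segment [starts[k], starts[k+1]).
    return sorted_labs[starts], np.maximum.reduceat(sorted_vals, starts)

# Small made-up data for a quick check against the masked-indexing version.
rng = np.random.default_rng(0)
vals = rng.random((26, 52))
labs = rng.integers(1, 9, size=(26, 52))

ids, maxes = grouped_max(vals, labs)
expected = [vals[labs == i].max() for i in ids]
```

The returned ids come out in ascending order, the same order np.unique would give, so maxes lines up with a sorted UNIQ_IDS.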

unutbu answered Oct 21 '22 08:10
