
Why does numpy.apply_along_axis seem to be slower than Python loop?

Tags: python, numpy

I'm confused about when numpy.apply_along_axis() will outperform a simple Python loop. For example, take a matrix with many rows and compute the sum of each row:

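import numpy as np
x = np.ones([100000, 3])
# Python loop: one np.sum call per row
sums1 = np.array([np.sum(x[i,:]) for i in range(x.shape[0])])
# apply_along_axis: np.sum is applied to each 1-D slice along axis 1
sums2 = np.apply_along_axis(np.sum, 1, x)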

Here I am even using a built-in NumPy function, np.sum, and yet calculating sums1 (Python loop) takes less than 400 ms while calculating sums2 (apply_along_axis) takes over 2000 ms (NumPy 1.6.1 on Windows). By way of further comparison, R's rowMeans function can often do this in less than 20 ms (I'm pretty sure it's calling C code), while the similar R function apply() can do it in about 600 ms.

Abiel asked Dec 28 '11 06:12


1 Answer

np.sum takes an axis parameter, so you can compute the row sums directly with

sums3 = np.sum(x, axis=1)

This is much faster than either of the two methods you proposed.

$ python -m timeit -n 1 -r 1 -s "import numpy as np;x=np.ones([100000,3])" "np.apply_along_axis(np.sum, 1, x)"
1 loops, best of 1: 3.21 sec per loop

$ python -m timeit -n 1 -r 1 -s "import numpy as np;x=np.ones([100000,3])" "np.array([np.sum(x[i,:]) for i in range(x.shape[0])])"
1 loops, best of 1: 712 msec per loop

$ python -m timeit -n 1 -r 1 -s "import numpy as np;x=np.ones([100000,3])" "np.sum(x, axis=1)"
1 loops, best of 1: 1.81 msec per loop
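The same comparison can be run from a script instead of the shell. The following is a minimal sketch, not a rigorous benchmark: the function names and the reduced row count (n_rows = 20000, chosen only to keep the run short) are assumptions for illustration, and absolute timings will vary with the NumPy version and machine.

```python
import timeit

import numpy as np


def row_sums_loop(x):
    # Python-level loop: one np.sum call per row.
    return np.array([np.sum(x[i, :]) for i in range(x.shape[0])])


def row_sums_apply(x):
    # apply_along_axis: np.sum is called on each 1-D slice along axis 1.
    return np.apply_along_axis(np.sum, 1, x)


def row_sums_vectorized(x):
    # A single call; the reduction over axis 1 happens inside NumPy.
    return np.sum(x, axis=1)


n_rows = 20000  # assumption: smaller than the question's 100000 rows
x = np.ones([n_rows, 3])

variants = (row_sums_loop, row_sums_apply, row_sums_vectorized)

# Check that all three approaches agree before timing them.
results = {func.__name__: func(x) for func in variants}

for func in variants:
    # Best of 3 single runs, reported in milliseconds.
    best = min(timeit.repeat(lambda: func(x), number=1, repeat=3))
    print(f"{func.__name__}: {best * 1e3:.2f} ms")
```

Taking the best of several repeats is less noisy than the single run (-n 1 -r 1) used in the shell commands above, but the ordering of the three approaches is what matters, not the exact figures.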

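(As for why apply_along_axis is slower: I don't know for certain, but it is probably because the function is written in pure Python and is much more generic. It calls the supplied function once per 1-D slice, so it leaves far less room for optimization than the array version, which does the whole reduction in one call.)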
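To make the per-slice overhead concrete, here is a simplified sketch of what a row-wise apply amounts to for a 2-D array. The helper apply_along_rows_sketch is hypothetical and is not NumPy's actual implementation, which also handles arbitrary axes, dimensions, and output shapes; it only shows that the supplied function is invoked once per row from Python.

```python
import numpy as np


def apply_along_rows_sketch(func1d, arr):
    """Hypothetical, simplified stand-in for np.apply_along_axis(func1d, 1, arr).

    Assumes arr is 2-D and func1d returns a scalar for each row.
    """
    out = np.empty(arr.shape[0], dtype=float)
    for i in range(arr.shape[0]):
        # One Python-level function call per row: this is the overhead
        # that np.sum(arr, axis=1) avoids.
        out[i] = func1d(arr[i, :])
    return out


x = np.arange(12, dtype=float).reshape(4, 3)
sketch = apply_along_rows_sketch(np.sum, x)
reference = np.apply_along_axis(np.sum, 1, x)
vectorized = np.sum(x, axis=1)
print(sketch)
```

All three results are equal; the difference is only in how many Python-level calls are needed to produce them.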

kennytm answered Oct 07 '22 00:10