
NumPy vs Cython - nested loop so slow?

Tags: python, numpy

I am confused about why a nested Python loop over a 3D NumPy array is so slow compared with Cython. I wrote a trivial example.

Python/NumPy version:

import numpy as np

def my_func(a,b,c):
    s=0
    for z in xrange(401):
        for y in xrange(401):
            for x in xrange(401):
                if a[z,y,x] == 0 and b[x,y,z] >= 0:
                    c[z,y,x] = 1
                    b[z,y,x] = z*y*x
                    s+=1
    return s

a = np.zeros((401,401,401), dtype=np.float32)
b = np.zeros((401,401,401), dtype=np.uint32)
c = np.zeros((401,401,401), dtype=np.uint8)

s = my_func(a,b,c)

Cythonized version:

cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def my_func(np.float32_t[:,:,::1] a, np.uint32_t[:,:,::1] b, np.uint8_t[:,:,::1] c):
    cdef np.uint16_t z,y,x
    cdef np.uint32_t s = 0

    for z in range(401):
        for y in range(401):
            for x in range(401):
                if a[z,y,x] == 0 and b[x,y,z] >= 0:
                    c[z,y,x] = 1
                    b[z,y,x] = z*y*x
                    s = s+1
    return s

The Cythonized version of my_func() runs approx. 6500x faster. A simpler function with only an if-statement and array access can be even 10000x faster. The Python version of my_func() takes 500.651 sec. to finish. Is iterating over a relatively small 3D array really this slow, or did I make a mistake in the code?

Cython version 0.21.1, Python 2.7.5, GCC 4.8.1, Xubuntu 13.10.

MarcinBurz asked Dec 19 '22 10:12


2 Answers

Python is an interpreted language. One of the benefits of compiling to machine code is the huge speedup you get, especially with things like nested loops.

I don't know what your expectations are, but all interpreted languages will be terribly slow at the things you are trying to do (JIT compiling may help to some extent though).

The trick to getting good performance out of NumPy (or MATLAB or anything similar) is to avoid looping altogether and instead refactor your code into a few operations on large arrays. This way, the looping takes place in the (heavily optimized) machine code libraries instead of in your Python code.
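As a rough illustration of that advice, here is a minimal sketch that counts matching elements twice: once with an explicit nested loop and once with a single whole-array expression. The function names `count_loop` and `count_vectorized`, the small 8x8x8 arrays, and the condition `b > 2` are made up for this example and are not from the question; the sketch uses Python 3's `range` rather than `xrange`.

```python
import numpy as np

def count_loop(a, b):
    """Count matching elements one at a time in interpreted Python."""
    s = 0
    for z in range(a.shape[0]):
        for y in range(a.shape[1]):
            for x in range(a.shape[2]):
                if a[z, y, x] == 0 and b[z, y, x] > 2:
                    s += 1
    return s

def count_vectorized(a, b):
    """Same count, but the looping happens inside NumPy's compiled code."""
    return int(np.count_nonzero((a == 0) & (b > 2)))

# Small hypothetical test arrays so the loop version finishes quickly.
rng = np.random.default_rng(0)
a = rng.integers(0, 2, size=(8, 8, 8)).astype(np.float32)
b = rng.integers(0, 5, size=(8, 8, 8)).astype(np.uint32)

loop_count = count_loop(a, b)
vec_count = count_vectorized(a, b)
```

Both functions should return the same count; only the vectorized one avoids paying the interpreter overhead once per element.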

Krumelur answered Jan 11 '23 06:01


As mentioned by Krumelur, Python loops are definitely slow. You can, however, use NumPy to your advantage. Operations on entire arrays are quite fast, although you sometimes need a little ingenuity.

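For instance, in your code the loop tests b[x,y,z] but writes b[z,y,x], so it can read entries of b that an earlier iteration has already modified. That does not change the outcome here, because every value written is a non-negative product and b is unsigned, so the test b >= 0 gives the same result before and after the write. Given that (and given that all three dimensions are equal, so the transpose has the same shape), the following should be equivalent, though you'll definitely want to check it yourself: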

# Precalculate a matrix of z*y*x.
# np.indices returns an array of shape (3, nz, ny, nx): one index grid per axis.
zi, yi, xi = np.indices(a.shape)
prod = zi * yi * xi

# Use array-wide logical operations to compute the condition
# from a and the transpose of b (b.T[z,y,x] is b[x,y,z])
condition = np.logical_and(a == 0, b.T >= 0)

# Use condition to alter b and c only where condition is true
b[condition] = prod[condition]
c[condition] = 1

s = condition.sum()

Note that this calculates x*y*z for every element, even where the condition is false. You could avoid that if it turns out to take a lot of time, but it is unlikely to be a significant factor.
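One possible way to skip the product where the condition is false, and to check the rewrite against the original loop, is sketched below. This is only an illustration under stated assumptions: `loop_version` and `vectorized_version` are hypothetical names, the test cube is a small 6x6x6 array with random contents rather than the 401x401x401 zeros from the question, and the code is written for Python 3.

```python
import numpy as np

def loop_version(a, b, c):
    """Reference implementation: the explicit triple loop from the question."""
    s = 0
    n = a.shape[0]
    for z in range(n):
        for y in range(n):
            for x in range(n):
                if a[z, y, x] == 0 and b[x, y, z] >= 0:
                    c[z, y, x] = 1
                    b[z, y, x] = z * y * x
                    s += 1
    return s

def vectorized_version(a, b, c):
    """Whole-array rewrite that forms z*y*x only where the condition holds."""
    condition = (a == 0) & (b.T >= 0)     # evaluated before b is modified
    zi, yi, xi = np.nonzero(condition)    # coordinates of the True entries
    b[zi, yi, xi] = zi * yi * xi
    c[zi, yi, xi] = 1
    return int(condition.sum())

# Small hypothetical cube so the reference loop runs quickly.
n = 6
rng = np.random.default_rng(1)
a0 = rng.integers(0, 2, size=(n, n, n)).astype(np.float32)
b0 = rng.integers(0, 100, size=(n, n, n)).astype(np.uint32)

b_loop, c_loop = b0.copy(), np.zeros((n, n, n), dtype=np.uint8)
b_vec, c_vec = b0.copy(), np.zeros((n, n, n), dtype=np.uint8)

s_loop = loop_version(a0, b_loop, c_loop)
s_vec = vectorized_version(a0, b_vec, c_vec)
```

Comparing `s_loop` with `s_vec`, and the modified `b` and `c` arrays with `np.array_equal`, is a cheap way to gain confidence in the rewrite before running it on the full-size arrays.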

Gretchen answered Jan 11 '23 05:01