Creating small arrays in Cython takes a humongous amount of time

I was writing a new random number generator for numpy that produces random numbers according to an arbitrary distribution when I came across this really weird behavior:

This is test.pyx:

#cython: boundscheck=False
#cython: wraparound=False
import numpy as np
cimport numpy as np
cimport cython

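# Does nothing: measures the call overhead of typed buffer arguments alone
def BareBones(np.ndarray[double, ndim=1] a,np.ndarray[double, ndim=1] u,r):
    return u

# Untyped arguments plus a trivial C loop over the elements of u
def UntypedWithLoop(a,u,r):
    cdef int i,j=0
    for i in range(u.shape[0]):
        j+=i
    return u,j

# Same arguments as np.searchsorted(a, u); allocates and returns a new int
# array, but does only one trivial operation per element
def BSReplacement(np.ndarray[double, ndim=1] a, np.ndarray[double, ndim=1] u):
    cdef np.ndarray[np.int_t, ndim=1] r=np.empty(u.shape[0],dtype=int)
    cdef int i,j=0
    for i in range(u.shape[0]):
        j=i
    return r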

setup.py:

from setuptools import setup
from Cython.Build import cythonize
import numpy
setup(name = "simple cython func", ext_modules = cythonize('test.pyx'),
      include_dirs = [numpy.get_include()])  # headers needed by `cimport numpy`

Profiling code:

#!/usr/bin/env python
from __future__ import division, print_function

import subprocess
import sys
import timeit

#Compile the cython modules before importing them
subprocess.call([sys.executable, 'setup.py', 'build_ext', '--inplace'])

# Setup code runs once per timeit.timeit() call; the timed statement
# then runs 1,000,000 times (timeit's default `number`)
sstr="""
import test
import numpy
u=numpy.random.random(10)
a=numpy.random.random(10)
a=numpy.cumsum(a)
a/=a[-1]
r=numpy.empty(10,int)
"""

print("binary search: creates an array[N] and performs N binary searches to fill it:\n",timeit.timeit('numpy.searchsorted(a,u)',sstr))
print("Simple replacement for binary search:takes the same args as np.searchsorted and similarly returns a new array. this performs only one trivial operation per element:\n",timeit.timeit('test.BSReplacement(a,u)',sstr))

print("barebones function doing nothing:",timeit.timeit('test.BareBones(a,u,r)',sstr))
print("Untyped inputs and doing N iterations:",timeit.timeit('test.UntypedWithLoop(a,u,r)',sstr))
print("time for just np.empty()",timeit.timeit('numpy.empty(10,int)',sstr))

The binary search implementation takes time on the order of len(u)*log(len(a)). The trivial Cython function takes time on the order of len(u). Both return a 1D int array of length len(u).

However, even this trivial implementation, which does no real computation, takes longer than the full binary search in the NumPy library (which is written in C: https://github.com/numpy/numpy/blob/202e78d607515e0390cffb1898e11807f117b36a/numpy/core/src/multiarray/item_selection.c; see PyArray_SearchSorted).
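For reference, here is a minimal pure-Python sketch of the per-element work a binary search does: each query halves the search interval over the sorted array a, so N queries cost about N*log2(len(a)) steps. The helper name `lower_bound_indices` is hypothetical and this is not NumPy's C code; it assumes `a` is sorted ascending and follows the side='left' convention that np.searchsorted(a, u) uses by default.

```python
def lower_bound_indices(a, u):
    """Hypothetical helper: for each x in u, the first index i with a[i] >= x.

    Assumes `a` is sorted ascending; this is the side='left' convention that
    np.searchsorted(a, u) uses by default. Each query halves the interval
    [lo, hi), so it costs about log2(len(a)) steps.
    """
    result = []
    for x in u:
        lo, hi = 0, len(a)
        while lo < hi:
            mid = (lo + hi) // 2
            if a[mid] < x:
                lo = mid + 1   # every element up to mid is < x
            else:
                hi = mid       # a[mid] >= x, so mid is still a candidate
        result.append(lo)
    return result


a = [0.1, 0.4, 0.7, 1.0]
print(lower_bound_indices(a, [0.05, 0.4, 0.5, 0.99]))  # [0, 1, 2, 3]
```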

The results are:

binary search: creates an array[N] and performs N binary searches to fill it:
1.15157485008
Simple replacement for binary search:takes the same args as np.searchsorted and similarly returns a new array. this performs only one trivial operation per element:
3.69442796707
barebones function doing nothing: 0.87496304512
Untyped inputs and doing N iterations: 0.244267940521
time for just np.empty() 1.0983929634
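These figures are totals: timeit.timeit(stmt, setup) runs the statement `number` times, 1,000,000 by default, so each value in seconds is also the cost in microseconds per call. A small sketch of that conversion applied to the numbers reported above (the helper name is illustrative):

```python
# timeit.timeit(stmt, setup) runs stmt `number` times, 1,000,000 by default,
# and returns the total in seconds.
DEFAULT_NUMBER = 1_000_000


def per_call_us(total_seconds, number=DEFAULT_NUMBER):
    """Convert a timeit total (seconds for `number` runs) to microseconds per call."""
    return total_seconds / number * 1e6


# Totals reported in the question for N=10 (default number of runs).
reported = {
    "np.searchsorted(a, u)": 1.15157485008,
    "test.BSReplacement(a, u)": 3.69442796707,
    "test.BareBones(a, u, r)": 0.87496304512,
    "test.UntypedWithLoop(a, u, r)": 0.244267940521,
    "numpy.empty(10, int)": 1.0983929634,
}

for stmt, total in reported.items():
    print(f"{stmt:30s} {per_call_us(total):.2f} us per call")
```

Read this way, a single np.empty(10, int) call costs about 1.1 µs, roughly as much as the entire np.searchsorted call at N=10.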

Why is the np.empty() step taking so much time, and what can I do to get an empty array that I can return?

The C function does this AND runs a whole bunch of sanity checks AND uses a longer algorithm in the inner loop. (I removed all the logic except the loop itself from my example.)


Update

It turns out there are two distinct problems:

  1. The np.empty(10) call alone has a ginormous overhead: it takes as much time as searchsorted needs to create a new array AND perform 10 binary searches to fill it.
  2. Just declaring arguments with the buffer syntax np.ndarray[...] also has a massive overhead, taking up MORE time than receiving untyped variables AND iterating 50 times.

Results for 50 iterations:

binary search: 2.45336699486
Simple replacement: 3.71126317978
barebones function doing nothing: 0.924916028976
Untyped inputs and doing N iterations: 0.316384077072
time for just np.empty() 1.04949498177
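Comparing the two runs gives a rough split between fixed per-call cost and per-element cost. A sketch that fits t(N) = fixed + slope*N through the two reported totals, assuming the first run used N=10 and the second N=50 and that both ran under comparable conditions (they are separate runs, so small differences are likely noise):

```python
def fixed_and_per_element(t_small, t_large, n_small=10, n_large=50):
    """Fit t(N) = fixed + slope * N through two timings.

    Inputs are timeit totals for 1,000,000 calls, i.e. microseconds per call,
    so `fixed` is in microseconds and `slope` in microseconds per element.
    """
    slope = (t_large - t_small) / (n_large - n_small)
    fixed = t_small - slope * n_small
    return fixed, slope


# (N=10 total, N=50 total) as reported in the question; N=50 is an assumption.
timings = {
    "np.searchsorted": (1.15157485008, 2.45336699486),
    "BSReplacement": (3.69442796707, 3.71126317978),
    "UntypedWithLoop": (0.244267940521, 0.316384077072),
}

for name, (t10, t50) in timings.items():
    fixed, slope = fixed_and_per_element(t10, t50)
    print(f"{name:16s} fixed ~{fixed:.2f} us/call, ~{slope * 1000:.1f} ns per element")
```

With these numbers np.searchsorted comes out at roughly 0.83 µs fixed plus about 33 ns per element, while BSReplacement is about 3.69 µs fixed with a per-element term under 1 ns, so its cost is almost entirely per-call overhead. BareBones, which does no per-element work, moved by about 0.05 µs between the runs, which gives a sense of the noise.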
asked Aug 23 '13 at 19:08 by staticd

2 Answers

There is a discussion of this on the Cython list that might have some useful suggestions: https://groups.google.com/forum/#!topic/cython-users/CwtU_jYADgM

Generally, though, I try to allocate small arrays outside of Cython, pass them in, and reuse them in subsequent calls to the method. I understand that this is not always an option.
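A minimal pure-Python sketch of that allocate-once-and-reuse pattern, using the standard library's `array` and `bisect` modules as stand-ins for a NumPy array and a compiled loop. `searchsorted_into` is a hypothetical helper, not a NumPy function; in the Cython setting the buffer would instead be created once with np.empty and passed to the compiled function on each call.

```python
from array import array
from bisect import bisect_left


def searchsorted_into(a, u, out):
    """Hypothetical helper: write bisect_left(a, x) for each x in u into `out`.

    `out` is allocated once by the caller and refilled in place, so no new
    container is created per call.
    """
    if len(out) != len(u):
        raise ValueError("out must have the same length as u")
    for j, x in enumerate(u):
        out[j] = bisect_left(a, x)
    return out


a = [0.1, 0.4, 0.7, 1.0]
out = array("q", [0]) * 3          # allocated once ("q" = signed 64-bit int)
for u in ([0.05, 0.5, 0.99], [0.2, 0.8, 0.3]):
    searchsorted_into(a, u, out)   # the same buffer is refilled on every call
    print(list(out))
```

The loop prints [0, 2, 3] and then [1, 3, 1]; only the contents of `out` change between calls.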

answered Sep 24 '22 at 10:09 by JoshAdel


Creating an array with np.empty inside the Cython function has some overhead, as you already saw. Here is an example of how to create the empty array outside and pass it to the Cython module to be filled with the correct values:

n=10:

numpy.searchsorted: 1.30574745517
cython O(1): 3.28732016088
cython no array declaration 1.54710909596

n=100:

numpy.searchsorted: 4.15200545373
cython O(1): 13.7273431067
cython no array declaration 11.4186086744

As you already pointed out, the NumPy version scales better, since it is O(len(u)*log(len(a))) while the algorithm here is O(len(u)*len(a)).
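To see that scaling without timing noise, a small sketch that counts loop iterations instead of measuring time: a JustLoop-style linear scan versus a lower-bound binary search, on evenly spaced tables of 10 and 100 values. The tables and query values are made up for illustration, and iteration counts ignore constant factors such as call overhead, which dominate the timings above at n=10.

```python
def linear_scan_iterations(a, x):
    """Inner-loop iterations a JustLoop-style scan spends on x (0 on early exit)."""
    if x < a[0] or x > a[-1]:
        return 0
    for i in range(1, len(a)):
        if a[i - 1] <= x < a[i]:
            return i          # the loop ran i times before breaking
    return len(a) - 1         # x == a[-1]: scanned to the end without a match


def binary_search_iterations(a, x):
    """Halving steps a lower-bound binary search takes to place x in sorted a."""
    lo, hi, steps = 0, len(a), 0
    while lo < hi:
        mid = (lo + hi) // 2
        steps += 1
        if a[mid] < x:
            lo = mid + 1
        else:
            hi = mid
    return steps


def average_iterations(n, samples=1000):
    """Average iterations per query for an evenly spaced table of n values."""
    a = [(i + 1) / n for i in range(n)]                # sorted, ends at 1.0
    u = [(k + 0.5) / samples for k in range(samples)]  # evenly spread in (0, 1)
    lin = sum(linear_scan_iterations(a, x) for x in u) / samples
    bis = sum(binary_search_iterations(a, x) for x in u) / samples
    return lin, bis


for n in (10, 100):
    lin, bis = average_iterations(n)
    print(f"n={n}: linear scan {lin:.1f}, binary search {bis:.1f} iterations per query")
```

The linear scan averages 4.5 iterations at n=10 and 49.5 at n=100 (about n/2), while the binary search stays within about log2(n)+1 steps.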

I also tried using memoryviews, i.e. replacing np.ndarray[double, ndim=1] with double[:], but the first option was faster in this case.

The new .pyx file is:

from __future__ import division
import numpy as np
cimport numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
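def JustLoop(np.ndarray[double, ndim=1] a, np.ndarray[double, ndim=1] u,
             np.ndarray[np.int_t, ndim=1] r):
    # r is preallocated by the caller with dtype=int and filled in place.
    # np.int_t is the buffer type WithArray uses for dtype=int; a C `int`
    # buffer raises a dtype mismatch wherever NumPy's default int is 64-bit.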
    cdef int i,j
    for j in range(u.shape[0]):
        if u[j] < a[0]:
            r[j] = 0
            continue

        if u[j] > a[a.shape[0]-1]:
            r[j] = a.shape[0]-1
            continue

        for i in range(1, a.shape[0]):
            if u[j] >= a[i-1] and u[j] < a[i]:
                r[j] = i
                break

@cython.boundscheck(False)
@cython.wraparound(False)
def WithArray(np.ndarray[double, ndim=1] a, np.ndarray[double, ndim=1] u):
    cdef np.ndarray[np.int_t, ndim=1] r=np.empty(u.shape[0],dtype=int)
    cdef int i,j
    for j in range(u.shape[0]):
        if u[j] < a[0]:
            r[j] = 0
            continue

        if u[j] > a[a.shape[0]-1]:
            r[j] = a.shape[0]-1
            continue

        for i in range(1, a.shape[0]):
            if u[j] >= a[i-1] and u[j] < a[i]:
                r[j] = i
                break
    return r
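Because JustLoop fills a caller-provided buffer, it is worth checking which index convention it produces. A pure-Python port for checking only (the name `just_loop_py` is hypothetical) shows that, for strictly increasing a and values inside [a[0], a[-1]), its rule a[i-1] <= x < a[i] gives the same index as bisect.bisect_right, i.e. np.searchsorted(a, u, side='right'), whereas np.searchsorted(a, u) defaults to side='left'; the two sides only differ when a value equals an element of a. At the edges it differs: values above a[-1] get len(a)-1 instead of len(a), and a value exactly equal to a[-1] matches no branch, so its slot keeps whatever the buffer held (uninitialized memory when the buffer comes from np.empty). In the benchmark u lies in [0, 1) and a[-1] is 1.0, so these edge cases do not arise there.

```python
from bisect import bisect_right


def just_loop_py(a, u, r):
    """Pure-Python port of the JustLoop Cython function above (for checking only)."""
    n = len(a)
    for j, x in enumerate(u):
        if x < a[0]:
            r[j] = 0
            continue
        if x > a[n - 1]:
            r[j] = n - 1
            continue
        for i in range(1, n):
            if x >= a[i - 1] and x < a[i]:
                r[j] = i
                break
    return r


a = [0.1, 0.4, 0.7, 1.0]
u = [0.05, 0.1, 0.5, 0.99, 1.0, 1.5]
print(just_loop_py(a, u, [-1] * len(u)))  # [0, 1, 2, 3, -1, 3]
print([bisect_right(a, x) for x in u])    # [0, 1, 2, 3, 4, 4]
```

The -1 sentinel in the fifth slot marks the value equal to a[-1] that no branch writes.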

The new .py file:

from __future__ import print_function
import numpy
import subprocess
import sys
import timeit

#Compile the cython modules before importing them
subprocess.call([sys.executable, 'setup.py', 'build_ext', '--inplace'])
from test import *

# Setup code: r is preallocated once here and then reused by every JustLoop call
sstr="""
import test
import numpy
u=numpy.random.random(10)
a=numpy.random.random(10)
a=numpy.cumsum(a)
a/=a[-1]
a.sort()
r = numpy.empty(u.shape[0], dtype=int)
"""

print("numpy.searchsorted:",timeit.timeit('numpy.searchsorted(a,u)',sstr))
print("cython O(1):",timeit.timeit('test.WithArray(a,u)',sstr))
print("cython no array declaration",timeit.timeit('test.JustLoop(a,u,r)',sstr))
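Single timeit.timeit totals like the ones above can vary between runs. A common refinement, sketched here with a hypothetical helper `best_us_per_call`, is to use timeit.repeat with a smaller `number`, take the minimum, and report it per call so results with different `number` values stay comparable. The statement timed below is a standard-library stand-in so the sketch runs without the compiled module.

```python
import timeit


def best_us_per_call(stmt, setup="pass", number=100_000, repeat=5):
    """Run timeit.repeat and return the fastest run in microseconds per call.

    The minimum over several repeats is less affected by other activity on
    the machine than a single total; dividing by `number` makes results
    comparable across different `number` values.
    """
    totals = timeit.repeat(stmt, setup, repeat=repeat, number=number)
    return min(totals) / number * 1e6


# Stand-in statement (standard library only); with the compiled module you
# would pass e.g. 'test.JustLoop(a,u,r)' and the sstr setup string instead.
setup = "import bisect; a = [i / 10 for i in range(1, 11)]"
print(f"{best_us_per_call('bisect.bisect_left(a, 0.5)', setup):.3f} us per call")
```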

answered Sep 25 '22 at 10:09 by Saullo G. P. Castro