
Cython: optimize the critical part of a numpy array summation

Let L be a list L = [A_1, A_2, ..., A_n], where each A_i is a numpy.int32 array of length 1024.

(Most of the time 1000 < n < 4000).

After some profiling, I have seen that one of the most time-consuming operations is the summation:

def summation():
    # L is a global variable, modified outside of this function
    b = numpy.zeros(1024, numpy.int32)
    for a in L:
        b += a
    return b

PS: I don't think I can define a 2D array of size 1024 x n, because n is not fixed: elements are added to and removed from L dynamically, so len(L) = n can vary between 1000 and 4000 at run time.
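
A varying n does not by itself rule out a 2D layout. The following is only a minimal sketch, assuming an upper bound of 4000 rows (taken from the range above); the class RowBuffer and its method names are hypothetical and not part of numpy. It preallocates the storage once and sums only the rows currently in use. Whether it beats the list-based loop would have to be measured.

```python
import numpy as np

WIDTH = 1024      # length of each array, as stated in the question
CAPACITY = 4000   # assumed upper bound on n


class RowBuffer:
    """Hypothetical container: fixed 2D storage, variable number of active rows."""

    def __init__(self, capacity=CAPACITY, width=WIDTH):
        self.data = np.zeros((capacity, width), dtype=np.int32)
        self.n = 0  # number of rows currently in use

    def add(self, a):
        # Copy the new array into the next free row
        # (raises IndexError once the assumed capacity is exceeded).
        self.data[self.n] = a
        self.n += 1

    def remove(self, i):
        # Overwrite row i with the last active row. Row order is not
        # preserved, which does not matter for a sum.
        self.n -= 1
        self.data[i] = self.data[self.n]

    def summation(self):
        # Sum only the active rows; dtype=np.int32 matches the original loop.
        return self.data[:self.n].sum(axis=0, dtype=np.int32)
```

Note that add() copies each array into the buffer, so later in-place changes to the original array are not seen by the buffer.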

Can I get a significant improvement by using Cython? If so, how should I recode this small function in Cython (should I add some cdef typing)?

Or can you see other possible improvements?

Basj asked Feb 06 '14 22:02



1 Answer

Here is the Cython code. Make sure that every array in L is C_CONTIGUOUS:

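import cython
import numpy as np
cimport numpy as np

@cython.boundscheck(False)  # skip index bounds checks in the inner loop
@cython.wraparound(False)   # disable negative-index handling
def sum_list(list a):
    cdef int* x
    cdef int* b
    cdef int i, j
    cdef int count
    # All arrays are assumed to have the same length as the first one
    count = len(a[0])
    res = np.zeros_like(a[0])
    # Raw pointer to the result buffer
    b = <int *>((<np.ndarray>res).data)
    for j in range(len(a)):
        # Raw pointer to the j-th array; only valid for C-contiguous int32 data
        x = <int *>((<np.ndarray>a[j]).data)
        for i in range(count):
            b[i] += x[i]
    return res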

On my PC it's about 4x faster.
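
The contiguity requirement mentioned above can be enforced on the Python side before calling sum_list. Because the pointer casts read each buffer as C int, the dtype matters as well as the memory layout. A minimal sketch follows; the helper name as_c_int32 is hypothetical, while np.ascontiguousarray is the numpy function it relies on:

```python
import numpy as np


def as_c_int32(arrays):
    """Hypothetical helper: return the arrays as C-contiguous int32 arrays.

    np.ascontiguousarray copies only when an array is not already
    contiguous or does not already have the requested dtype.
    """
    return [np.ascontiguousarray(a, dtype=np.int32) for a in arrays]


# Example: a strided view is not C-contiguous, a freshly created array is.
L = [np.arange(2048, dtype=np.int32)[::2], np.ones(1024, dtype=np.int32)]
print([a.flags['C_CONTIGUOUS'] for a in L])

L = as_c_int32(L)
print([a.flags['C_CONTIGUOUS'] for a in L])
```

Any array that had to be copied is then a new object, so later in-place changes to the original would not be reflected in the normalized list.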

HYRY answered Nov 16 '22 18:11
