
Cython speedup isn't as large as expected

I have written a Python function that computes pairwise electromagnetic interactions between a largish number (N ~ 10^3) of particles and stores the results in an NxN complex128 ndarray. It runs, but it is the slowest part of a larger program, taking about 40 seconds when N=900 [corrected]. The original code looks like this:

import numpy as np
def interaction(s,alpha,kprop): # s is an Nx3 real array 
                                # alpha is complex
                                # kprop is float

    ndipoles = s.shape[0]

    Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=np.complex128)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    im = complex(0,1)

    k2 = kprop*kprop

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))
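For reference, here is the 3×3 block that the loop stores in Amat[i,:,j,:] for i ≠ j, read directly off the code above. Here k is kprop, R = |x_i − x_j|, n = (x_i − x_j)/R, I is the 3×3 identity, \mathrm{i} is the imaginary unit (distinct from the index i), and \mathcal{A} is the scalar called A in the code:

$$
\mathrm{Amat}_{ij} = -\alpha\,\frac{k^{2}e^{\mathrm{i}kR}}{R}\left[(3\mathcal{A}-1)\,\mathbf{n}\mathbf{n}^{\mathsf{T}} + (1-\mathcal{A})\,I\right],
\qquad
\mathcal{A} = \frac{1}{(kR)^{2}} - \frac{\mathrm{i}}{kR},
\qquad
\mathrm{Amat}_{ii} = I .
$$

Evaluated pair by pair, that is N(N−1) = 809,100 blocks for N = 900, each built from several small NumPy calls.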

I had never previously used Cython, but that seemed like a good place to start in my effort to speed things up, so I pretty much blindly adapted the techniques I found in online tutorials. I got some speedup (30 seconds vs. 40 seconds), but not nearly as dramatic as I expected, so I'm wondering whether I'm doing something wrong or am missing a critical step. The following is my best attempt at cythonizing the above routine:

import numpy as np
cimport numpy as np

DTYPE = np.complex128
ctypedef np.complex128_t DTYPE_t

def interaction(np.ndarray s, DTYPE_t alpha, float kprop):

    cdef float k2 = kprop*kprop
    cdef int i,j
    cdef np.ndarray xi, xj, dx, n, nxn
    cdef float R, kR, kR2
    cdef DTYPE_t A

    cdef int ndipoles = s.shape[0]
    cdef np.ndarray Amat = np.zeros((ndipoles,3, ndipoles, 3), dtype=DTYPE)
    cdef np.ndarray I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    cdef DTYPE_t im = complex(0,1)

    for i in range(ndipoles):
        xi = s[i,:]
        for j in range(ndipoles):
            if i != j:
                xj = s[j,:]
                dx = xi-xj
                R = np.sqrt(dx.dot(dx))
                n = dx/R
                kR = kprop*R
                kR2 = kR*kR
                A = ((1./kR2) - im/kR)
                nxn = np.outer(n, n)
                nxn = (3*A-1)*nxn + (1-A)*I
                nxn *= -alpha*(k2*np.exp(im*kR))/R
            else:
                nxn = I

            Amat[i,:,j,:] = nxn

    return(Amat.reshape((3*ndipoles,3*ndipoles)))
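One likely reason the typed version gains little: declaring variables as `cdef np.ndarray` does not turn `np.outer`, `np.sqrt`, `dx.dot` or the array arithmetic into C code. Every statement in the inner loop still makes Python-level NumPy calls on length-3 arrays, and for arrays that small the per-call overhead dominates. Note also that `cdef float` declares a 32-bit C float, whereas Python floats and NumPy's float64 are 64-bit doubles. For Cython to help, the loop body generally has to be written with scalar arithmetic only, so every local can be typed as a C `double` or `double complex`. The sketch below shows such a body in plain Python; `pair_block` and `pair_block_numpy` are hypothetical helper names, and the check only confirms the scalar form reproduces the original formula.

```python
import cmath
import math

import numpy as np


def pair_block(xi, xj, alpha, kprop):
    """3x3 interaction block for one pair (i != j), scalar arithmetic only.

    Hypothetical helper: no NumPy calls, which is the form a typed
    Cython loop body would need.
    """
    dx0 = xi[0] - xj[0]
    dx1 = xi[1] - xj[1]
    dx2 = xi[2] - xj[2]
    R = math.sqrt(dx0*dx0 + dx1*dx1 + dx2*dx2)
    n = (dx0 / R, dx1 / R, dx2 / R)                 # unit vector from j to i
    kR = kprop * R
    A = 1.0 / (kR * kR) - 1j / kR
    # common prefactor -alpha * k^2 * exp(i k R) / R
    pref = -alpha * (kprop * kprop) * cmath.exp(1j * kR) / R
    c_nn = (3 * A - 1) * pref                       # coefficient of n n^T
    c_id = (1 - A) * pref                           # coefficient of the identity
    block = [[c_nn * n[a] * n[b] for b in range(3)] for a in range(3)]
    for a in range(3):
        block[a][a] += c_id
    return block


def pair_block_numpy(xi, xj, alpha, kprop):
    """Same block, computed the way the original loop body does it."""
    dx = xi - xj
    R = np.sqrt(dx.dot(dx))
    n = dx / R
    kR = kprop * R
    A = 1. / (kR * kR) - 1j / kR
    nxn = (3 * A - 1) * np.outer(n, n) + (1 - A) * np.eye(3)
    return nxn * (-alpha * (kprop * kprop * np.exp(1j * kR)) / R)


rng = np.random.default_rng(0)
xi, xj = rng.random(3), rng.random(3)
print(np.allclose(pair_block(xi, xj, 3j, 5.4), pair_block_numpy(xi, xj, 3j, 5.4)))
```

In a .pyx file, each of these locals would get a `cdef double` or `cdef double complex` declaration, `s` could be passed as a typed memoryview such as `double[:, ::1]`, and C-level math would replace `math.sqrt` and `cmath.exp`, so the double loop can run without calling back into Python. This is an untested suggestion, not something measured here.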
Grant Petty asked Oct 11 '15 17:10


1 Answer

The real power of NumPy lies in performing an operation across a huge number of elements in a vectorized manner, rather than applying it piecemeal inside loops. In your case, you have two nested loops and an if statement. I would propose adding dimensions to the intermediate arrays so that NumPy's broadcasting comes into play: each operation is then applied to all (i, j) pairs in one go, instead of to one pair's tiny 3-element or 3x3 arrays per loop iteration.

The extra dimensions can be added by indexing with None (an alias for np.newaxis). A vectorized implementation following this idea looks like this:
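To see what the None indexing does, here is the first step on a tiny input: `s[:,None]` has shape (N, 1, 3), and subtracting `s` (shape (N, 3), broadcast as (1, N, 3)) yields every pairwise difference `s[i] - s[j]` in one (N, N, 3) array. The toy size and positions below are arbitrary illustration values.

```python
import numpy as np

N = 4
s = np.arange(N * 3, dtype=float).reshape(N, 3)   # toy positions, shape (N, 3)

sd = s[:, None] - s        # (N, 1, 3) - (1, N, 3) -> (N, N, 3)
print(sd.shape)            # (4, 4, 3)

# Entry [i, j] equals the loop's dx = s[i,:] - s[j,:]
i, j = 2, 0
print(np.array_equal(sd[i, j], s[i] - s[j]))      # True

# Pairwise distance matrix, as Rv in the answer; its diagonal is all zeros
Rv = np.sqrt((sd**2).sum(2))
print(Rv.shape)            # (4, 4)
```

The zero diagonal of Rv is why the i == j case needs special handling once everything is computed in bulk.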

import numpy as np

def vectorized_interaction(s,alpha,kprop):

    N = s.shape[0]                  # number of dipoles (previously an undefined global)
    im = complex(0,1)
    I = np.array([[1,0,0],[0,1,0],[0,0,1]])
    k2 = kprop*kprop

    # On the diagonal (i == j) R = 0, so these divisions give inf/nan there.
    # Those entries are overwritten with I below, so silence the warnings.
    with np.errstate(divide='ignore', invalid='ignore'):
        # Vectorized calculations for dx, R, n, kR, A
        sd = s[:,None] - s
        Rv = np.sqrt((sd**2).sum(2))
        nv = sd/Rv[:,:,None]
        kRv = Rv*kprop
        Av = (1./(kRv*kRv)) - im/kRv

        # Vectorized calculation for: "nxn = np.outer(n, n)"
        nxnv = nv[:,:,:,None]*nv[:,:,None,:]

        # Vectorized calculation for: "(3*A-1)*nxn + (1-A)*I"
        P = (3*Av[:,:,None,None]-1)*nxnv + (1-Av[:,:,None,None])*I

        # Vectorized calculation for: "-alpha*(k2*np.exp(im*kR))/R"
        multv = -alpha*(k2*np.exp(im*kRv))/Rv

        # Vectorized calculation for: "nxn *= -alpha*(k2*np.exp(im*kR))/R"
        outv = P*multv[:,:,None,None]

    # Simulate the ELSE branch of "if i != j:" by setting the
    # last two dimensions of every diagonal (i, i) entry to I
    outv[np.eye(N,dtype=bool)] = I

    # Reorder (i, j, a, b) -> (i, a, j, b) to match Amat[i,:,j,:], then flatten
    return outv.transpose(0,2,1,3).reshape(N*3,-1)
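One cost of this approach is memory: every intermediate of shape (N, N, 3, 3) holds 9N² elements, and the function keeps several of them alive at once (nxnv, P and outv, plus the temporaries created while evaluating P), most of them complex128 at 16 bytes per element. A quick estimate, where `block_array_bytes` is a hypothetical helper name:

```python
import numpy as np


def block_array_bytes(N, dtype=np.complex128):
    """Bytes needed for one (N, N, 3, 3) array of the given dtype."""
    return N * N * 3 * 3 * np.dtype(dtype).itemsize


for N in (100, 900, 3000):
    print(N, block_array_bytes(N) / 1e6, "MB per array")
```

For N = 900 that is about 117 MB per array, and for N = 3000 about 1.3 GB. If memory becomes the limit at larger N, one option is to apply the same broadcasting to blocks of rows of s at a time, trading a short Python loop over blocks for bounded memory.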

Runtime tests and output verification:

Case #1:

In [703]: N = 10
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [704]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print(np.allclose(np.real(out_org),np.real(out_vect)))
     ...: print(np.allclose(np.imag(out_org),np.imag(out_vect)))
     ...: 
True
True

In [705]: %timeit interaction(s,alpha,kprop)
100 loops, best of 3: 7.6 ms per loop

In [706]: %timeit vectorized_interaction(s,alpha,kprop)
1000 loops, best of 3: 304 µs per loop

Case #2:

In [707]: N = 100
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [708]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print(np.allclose(np.real(out_org),np.real(out_vect)))
     ...: print(np.allclose(np.imag(out_org),np.imag(out_vect)))
     ...: 
True
True

In [709]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 826 ms per loop

In [710]: %timeit vectorized_interaction(s,alpha,kprop)
100 loops, best of 3: 14 ms per loop

Case #3:

In [711]: N = 900
     ...: s = np.random.rand(N,3) + complex(0,1)*np.random.rand(N,3)
     ...: alpha = 3j
     ...: kprop = 5.4
     ...: 

In [712]: out_org = interaction(s,alpha,kprop)
     ...: out_vect = vectorized_interaction(s,alpha,kprop)
     ...: print(np.allclose(np.real(out_org),np.real(out_vect)))
     ...: print(np.allclose(np.imag(out_org),np.imag(out_vect)))
     ...: 
True
True

In [713]: %timeit interaction(s,alpha,kprop)
1 loops, best of 3: 1min 7s per loop

In [714]: %timeit vectorized_interaction(s,alpha,kprop)
1 loops, best of 3: 1.59 s per loop
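A side note on the checks above: `np.allclose` accepts complex arrays directly, comparing the modulus `abs(a - b)` against the tolerance, so a single call covers the real and imaginary parts together. A small demonstration with made-up values:

```python
import numpy as np

a = np.array([1 + 2j, 3 - 1j])

# A tiny perturbation of the imaginary part is within the default tolerances
print(np.allclose(a, a + 1e-12j))        # True

# Flipping the sign of one imaginary part is not
b = a.copy()
b[0] = 1 - 2j
print(np.allclose(a, b))                 # False

# The separate real/imag checks from the transcript flag the same mismatch
print(np.allclose(np.real(a), np.real(b)), np.allclose(np.imag(a), np.imag(b)))  # True False
```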
Divakar answered Nov 05 '22 20:11