Optimizing numpy.dot with Cython

I have the following piece of code which I'd like to optimize using Cython:

sim = numpy.dot(v1, v2) / (sqrt(numpy.dot(v1, v1)) * sqrt(numpy.dot(v2, v2))) 
dist = 1-sim
return dist

I have written and compiled the .pyx file, but when I run the code I do not see any significant improvement in performance. According to the Cython documentation, I have to add C types. The HTML annotation file generated by Cython indicates that the bottleneck is the dot products (which is expected, of course). Does this mean that I have to define a C function for the dot products? If so, how do I do that?

EDIT:

After some research I came up with the following code. The improvement is only marginal, and I am not sure whether there is anything more I can do to improve it:

from __future__ import division
import numpy as np
import math as m
cimport numpy as np
cimport cython

cdef extern from "math.h":
    double c_sqrt "sqrt"(double)

ctypedef np.float64_t reals  # typedef for easier reading; float64_t matches C double

cdef inline double dot(np.ndarray[reals,ndim = 1] v1, np.ndarray[reals,ndim = 1] v2):
  cdef double result = 0
  cdef int i = 0
  cdef int length = v1.size
  cdef double el1 = 0
  cdef double el2 = 0
  for i in range(length):
    el1 = v1[i]
    el2 = v2[i]
    result += el1*el2
  return result

@cython.cdivision(True)
def distance(np.ndarray[reals,ndim = 1] ex1, np.ndarray[reals,ndim = 1] ex2):
  cdef double dot12 = dot(ex1, ex2)
  cdef double dot11 = dot(ex1, ex1)
  cdef double dot22 = dot(ex2, ex2)
  cdef double sim = dot12 / (c_sqrt(dot11 * dot22))
  cdef double dist = 1-sim    
  return dist 
George Eracleous asked May 28 '12 17:05



1 Answer

As a general note, if you are calling NumPy functions from within Cython and doing little else, you will usually see only marginal gains, if any. Large speed-ups typically come from statically typing code that runs an explicit for loop at the Python level, not from code that is already calling into the NumPy C-API.
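To make that distinction concrete, here is a minimal sketch in plain Python. The function names cosine_distance_numpy and cosine_distance_loop are illustrative and do not come from the question. The first version hands its dot products to NumPy's compiled code, so Cython has little left to remove. The second does the same arithmetic in an explicit Python-level loop, which is the kind of code where static typing tends to pay off.

```python
import math
import timeit

import numpy as np


def cosine_distance_numpy(v1, v2):
    """Cosine distance using NumPy's compiled dot product."""
    sim = np.dot(v1, v2) / math.sqrt(np.dot(v1, v1) * np.dot(v2, v2))
    return 1.0 - sim


def cosine_distance_loop(v1, v2):
    """Same quantity computed with an explicit Python-level loop."""
    dot12 = dot11 = dot22 = 0.0
    for a, b in zip(v1, v2):
        dot12 += a * b
        dot11 += a * a
        dot22 += b * b
    return 1.0 - dot12 / math.sqrt(dot11 * dot22)


if __name__ == "__main__":
    # Arbitrary test vectors; the size is an assumption chosen for illustration.
    rng = np.random.default_rng(0)
    x = rng.random(10_000)
    y = rng.random(10_000)
    for fn in (cosine_distance_numpy, cosine_distance_loop):
        seconds = timeit.timeit(lambda: fn(x, y), number=20)
        print(f"{fn.__name__}: {seconds:.4f} s for 20 calls")
```

Timings depend on the machine and the vector length, so treat the printed numbers as a rough guide only. If the loop version is much slower, that gap is roughly what a typed Cython loop could recover. The NumPy version leaves far less room for improvement.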

You could try writing out the dot product yourself with static typing for the loop counter, the input NumPy arrays, and so on, with wraparound and boundscheck set to False. Import the C library version of the sqrt function, and then try a parallel for loop (prange) to make use of OpenMP.

JoshAdel answered Sep 17 '22 01:09
