How to write a fast log-sum-exp in Cython and Weave?

I am looking at options to accelerate the log-sum-exp operation (using the "max trick") in Python code. I am on Windows 8 using Python 2.7. I have put together a comparison of implementations using NumPy, SciPy's implementation, Numba, Cython, Weave and numexpr, which can be viewed here on nbviewer.
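For reference, a plain NumPy log-sum-exp with the max trick can be sketched as follows. The name `lse_numpy` is hypothetical, and this is not the notebook's own code, which is not reproduced here.

```python
import numpy as np

def lse_numpy(a):
    """Compute log(sum(exp(a))) using the max trick to avoid overflow."""
    a = np.asarray(a)
    m = a.max()                      # shift every exponent by the largest element
    return float(m + np.log(np.exp(a - m).sum()))
```

Without the shift, `np.exp(1000.0)` overflows to `inf`; with it, `lse_numpy([1000.0, 1000.0])` stays finite and equals `1000 + log(2)`.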

I had expected my Cython and Weave versions to be the fastest of all, as they are closest to native code. In fact, they are slower than my other versions.

How can I make these versions as fast as possible?

Edit: relative to the initial notebook, I added the max trick to all methods to make the comparison less trivial and closer to my actual need.
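The identity behind the max trick, for reference:

```latex
\log \sum_{i=1}^{n} e^{a_i} = m + \log \sum_{i=1}^{n} e^{a_i - m}, \qquad m = \max_i a_i
```

Since every exponent $a_i - m \le 0$, each term lies in $(0, 1]$ and the largest term equals 1, so the inner sum lies in $[1, n]$. It therefore cannot overflow, and the argument of the logarithm never underflows to zero.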

asked Nov 25 '13 at 14:11 by Sebastien



1 Answer

An explicitly vectorized (SSE) C version is about 2.5x faster on my machine than any of the alternatives you posted (~150 µs vs ~360 µs) for float32 data. I don't have Numba, so I couldn't try that.

http://nbviewer.ipython.org/github/rmcgibbo/logsumexp/blob/master/Accelerating%20log-sum-exp.ipynb
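The SSE source itself is not reproduced here. As a rough illustration of how an explicitly vectorized kernel usually organizes this reduction, the NumPy sketch below models four float32 lanes (one 128-bit SSE register): each lane accumulates its own partial sum, the four lanes are combined in a final horizontal sum, and leftover elements are handled as a scalar tail. The name `lse_lanes` and this decomposition are assumptions about the general SIMD pattern, not the author's code. The sketch shows structure only and is not expected to beat plain NumPy.

```python
import numpy as np

LANES = 4  # one SSE register holds four float32 values

def lse_lanes(a):
    """Log-sum-exp in float32, structured like a 4-lane SIMD reduction (illustrative only)."""
    a = np.asarray(a, dtype=np.float32)
    n_main = (a.size // LANES) * LANES
    # A real SIMD kernel would also find the maximum lane-wise; a.max() stands in for that.
    m = a.max()
    # "Vertical" pass: each column plays the role of one SIMD lane.
    block = a[:n_main].reshape(-1, LANES)
    lane_sums = np.exp(block - m).sum(axis=0)      # four partial sums
    # "Horizontal" reduction of the four lanes into a single value.
    total = lane_sums.sum()
    # Scalar tail for the elements that do not fill a whole register.
    total += np.exp(a[n_main:] - m).sum()
    return float(m + np.log(total))
```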

Note that this is only for float32. One of the disadvantages of explicit SSE code is that it's very datatype-specific, and I didn't take the effort to write a double-precision version.

The full source code for the SSE implementation (BSD-licensed), with a simple setup.py installer, is at https://github.com/rmcgibbo/logsumexp/tree/master

```
%timeit scipy.misc.logsumexp(a)
10.4467
1000 loops, best of 3: 363 µs per loop
10.4467144498
%timeit lse_weave(a)
1000 loops, best of 3: 352 µs per loop
10.4467
%timeit lse_numexpr(a)
1000 loops, best of 3: 360 µs per loop
10.4467162773
%timeit lse_cython(a)
1000 loops, best of 3: 361 µs per loop
10.4467163086
%timeit sselogsumexp.logsumexp(a)  # <--- my version
10000 loops, best of 3: 149 µs per loop
```
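To reproduce this kind of comparison outside IPython, the standard-library `timeit` module can mimic `%timeit`'s "best of N" report. In the sketch below, the helper names (`bench`, `lse_numpy`) and the input (100,000 standard-normal float32 values) are assumptions, since the answer does not state the array size.

```python
import timeit
import numpy as np

def lse_numpy(a):
    """Baseline NumPy log-sum-exp with the max trick."""
    m = a.max()
    return m + np.log(np.exp(a - m).sum())

def bench(funcs, a, number=1000, repeat=3):
    """Return the best time per call in microseconds for each function."""
    results = {}
    for name, f in funcs.items():
        times = timeit.repeat(lambda: f(a), number=number, repeat=repeat)
        results[name] = min(times) / number * 1e6
    return results

# Assumed test input: the original array size is not given.
a = np.random.default_rng(0).standard_normal(100_000).astype(np.float32)

if __name__ == "__main__":
    for name, us in bench({"numpy": lse_numpy}, a, number=100).items():
        print(f"{name}: {us:.1f} µs per loop")
```

Other implementations can be added to the dictionary. In current SciPy, the function is available as `scipy.special.logsumexp`.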
answered Sep 29 '22 at 10:09 by Robert T. McGibbon
