Numba code slower than pure python

I've been working on speeding up a resampling calculation for a particle filter. Since Python offers many ways to speed up numerical code, I thought I'd try them all. Unfortunately, the Numba version is incredibly slow. As Numba should give a speed-up, I assume this is an error on my part.

I tried 4 different versions:

  1. Numba
  2. Python
  3. Numpy
  4. Cython

The code for each is below:

import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample

@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results

#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython

DTYPE = np.float64

ctypedef np.float64_t DTYPE_t

@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs, 
             np.ndarray[DTYPE_t, ndim=1] xs, 
             np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE

    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j 
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""

if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)

    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)

This results in the following output:

Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop

Any idea why the numba code is so slow? I assumed it would be at least comparable to Numpy.

Note: if anyone has any ideas on how to speed up either the Numpy or Cython code samples, that would be nice too:) My main question is about Numba though.
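
The %timeit lines in the script above are IPython magics and only work inside an IPython session. Outside IPython, a rough equivalent can be sketched with the standard-library timeit module. The helper name best_time and its default arguments are assumptions made for this illustration, not part of the original post:

```python
import timeit

import numpy as np


def python_resample(qs, xs, rands):
    """Pure-Python reference, same logic as in the question."""
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results


def best_time(func, args, number=20, repeat=3):
    """Hypothetical helper: best average seconds per call.

    Runs `repeat` batches of `number` calls and reports the fastest batch,
    which is roughly what %timeit's "best of 3" line reports.
    """
    timer = timeit.Timer(lambda: func(*args))
    return min(timer.repeat(repeat=repeat, number=number)) / number


n = 100
xs = np.arange(n, dtype=np.float64)
qs = np.full(n, 1.0 / n)
rands = np.random.rand(n)

# Call once before timing so one-off costs (such as JIT compilation for a
# Numba function) are not included in the measurement.
python_resample(qs, xs, rands)
elapsed = best_time(python_resample, (qs, xs, rands))
print("python_resample: %.1f us per call" % (elapsed * 1e6))
```

The warm-up call matters most for the Numba version, because the first call to a JIT-compiled function includes compilation time.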

asked Jan 30 '14 21:01 by jiminy_crist



2 Answers

The problem is that Numba can't infer the type of lookup. If you put a print nb.typeof(lookup) in your function, you'll see that Numba is treating it as a generic Python object, which is slow. Normally I would just declare the type of lookup in a locals dict, but I was getting a strange error. Instead I wrote a small wrapper around np.cumsum, so that I could explicitly define its input and output types.

@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)

    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results

Then my timings are:

print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)

print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)

Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop

You can go even a little faster still if you use jit instead of autojit:

@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))

For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.
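
The code above targets the Numba release that was current in 2014. To my knowledge, nb.autojit has since been removed, and np.cumsum can now be typed directly in nopython mode, so the wrapper should no longer be needed. Below is a minimal sketch of the same function for a recent Numba, under those assumptions. The name numba_resample_njit, the ImportError fallback, and the default-to-last-particle guard are additions for this illustration and are not from the original answer:

```python
import numpy as np

try:
    # nopython-mode JIT: compilation fails with an error instead of silently
    # falling back to slow Python objects when a type cannot be inferred.
    from numba import njit
except ImportError:
    # Assumption for this sketch only: run as plain Python when Numba is
    # not installed, so the example still executes (without any speed-up).
    def njit(func):
        return func


@njit
def numba_resample_njit(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        # Default to the last particle in case floating-point rounding
        # leaves lookup[-1] slightly below rands[j].
        results[j] = xs[n - 1]
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results


n = 100
xs = np.arange(n, dtype=np.float64)
qs = np.full(n, 1.0 / n)
rands = np.random.rand(n)

# The first call includes compilation; time only the later calls.
out = numba_resample_njit(qs, xs, rands)
```

Using njit rather than a plain jit is a deliberate choice here: a typing problem like the one in the question shows up as a compile-time error rather than as an unexplained slowdown.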

answered Sep 25 '22 16:09 by JoshAdel


Faster numpy version (10x speedup compared to numpy_resample)

def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]
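
A further variation, not from the original answer: the broadcasting version builds an n-by-n boolean array, so its memory use grows quadratically with the number of particles. Because the cumulative sum of non-negative weights is sorted, the same lookup can be sketched with np.searchsorted, which does a binary search per draw. The function name numpy_searchsorted and the clamping of the index are assumptions made for this illustration:

```python
import numpy as np


def numpy_searchsorted(qs, xs, rands):
    lookup = np.cumsum(qs)
    # side='right' returns the first index i with lookup[i] > rands[j],
    # which is what np.argmax(lookup > key) finds whenever such an
    # index exists.
    idx = np.searchsorted(lookup, rands, side='right')
    # Guard against rounding that leaves lookup[-1] just below a draw.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]


def numpy_faster(qs, xs, rands):
    """Broadcasting version from the answer above, kept for comparison."""
    lookup = np.cumsum(qs)
    mm = lookup[None, :] > rands[:, None]
    return xs[np.argmax(mm, 1)]


n = 100
xs = np.arange(n, dtype=np.float64)
qs = np.full(n, 1.0 / n)
rands = (np.arange(n) + 0.5) / n  # deterministic draws, one per bin

# Both versions should pick the same particles for these inputs.
assert np.array_equal(numpy_searchsorted(qs, xs, rands),
                      numpy_faster(qs, xs, rands))
```

No timings are given here because none were measured; the functions can be timed against numpy_resample in the same way as the others.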

answered Sep 25 '22 16:09 by hpaulj