I've been working on speeding up a resampling calculation for a particle filter. As Python has many ways to speed code up, I thought I'd try them all. Unfortunately, the Numba version is incredibly slow. As Numba should result in a speed-up, I assume this is an error on my part.
I tried 4 different versions: Numba, pure Python, NumPy, and Cython.
The code for each is below:
import numpy as np
import scipy as sp
import numba as nb
from cython_resample import cython_resample
@nb.autojit
def numba_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
def python_resample(qs, xs, rands):
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
def numpy_resample(qs, xs, rands):
    results = np.empty_like(qs)
    lookup = sp.cumsum(qs)
    for j, key in enumerate(rands):
        i = sp.argmax(lookup>key)
        results[j] = xs[i]
    return results
#The following is the code for the cython module. It was compiled in a
#separate file, but is included here to aid in the question.
"""
import numpy as np
cimport numpy as np
cimport cython
DTYPE = np.float64
ctypedef np.float64_t DTYPE_t
@cython.boundscheck(False)
def cython_resample(np.ndarray[DTYPE_t, ndim=1] qs,
                    np.ndarray[DTYPE_t, ndim=1] xs,
                    np.ndarray[DTYPE_t, ndim=1] rands):
    if qs.shape[0] != xs.shape[0] or qs.shape[0] != rands.shape[0]:
        raise ValueError("Arrays must have same shape")
    assert qs.dtype == xs.dtype == rands.dtype == DTYPE
    cdef unsigned int n = qs.shape[0]
    cdef unsigned int i, j
    cdef np.ndarray[DTYPE_t, ndim=1] lookup = np.cumsum(qs)
    cdef np.ndarray[DTYPE_t, ndim=1] results = np.zeros(n, dtype=DTYPE)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
"""
if __name__ == '__main__':
    n = 100
    xs = np.arange(n, dtype=np.float64)
    qs = np.array([1.0/n,]*n)
    rands = np.random.rand(n)
    print "Timing Numba Function:"
    %timeit numba_resample(qs, xs, rands)
    print "Timing Python Function:"
    %timeit python_resample(qs, xs, rands)
    print "Timing Numpy Function:"
    %timeit numpy_resample(qs, xs, rands)
    print "Timing Cython Function:"
    %timeit cython_resample(qs, xs, rands)
This results in the following output:
Timing Numba Function:
1 loops, best of 3: 8.23 ms per loop
Timing Python Function:
100 loops, best of 3: 2.48 ms per loop
Timing Numpy Function:
1000 loops, best of 3: 793 µs per loop
Timing Cython Function:
10000 loops, best of 3: 25 µs per loop
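The %timeit lines above only work inside IPython. For anyone reproducing the comparison from a plain script, here is a minimal sketch using the standard-library timeit module. The helper name best_time and its repeat/number defaults are my own choices for illustration, not part of the original benchmark, and the numbers it reports will differ from the ones above.

```python
import timeit

import numpy as np


def best_time(func, *args, repeat=3, number=10):
    """Return the best per-call time in seconds (hypothetical helper,
    roughly in the spirit of IPython's %timeit "best of 3")."""
    timer = timeit.Timer(lambda: func(*args))
    # Each entry of repeat() is the total time for `number` calls.
    return min(timer.repeat(repeat=repeat, number=number)) / number


def python_resample(qs, xs, rands):
    # Same pure-Python double loop as in the question.
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results


n = 100
xs = np.arange(n, dtype=np.float64)
qs = np.array([1.0 / n] * n)
rands = np.random.rand(n)

print("Python function: %.1f us per call"
      % (best_time(python_resample, qs, xs, rands) * 1e6))
```

The other versions can be timed the same way by passing them to best_time in place of python_resample.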
Any idea why the numba code is so slow? I assumed it would be at least comparable to Numpy.
Note: if anyone has any ideas on how to speed up either the Numpy or Cython code samples, that would be nice too:) My main question is about Numba though.
Numba speeds up code through function decorators, so the computation must be enclosed in a function. The most widely used decorator is @jit, which marks a function for optimization by Numba's JIT compiler.
For the uninitiated, Numba is an open-source JIT compiler that translates a subset of Python/NumPy code into optimized machine code using the LLVM compiler library. In short, Numba makes Python/NumPy code run faster.
Numba is a just-in-time compiler for Python that is specifically focused on code that runs in loops over NumPy arrays.
Numba-compiled algorithms can run orders of magnitude faster than the equivalent pure Python code and may approach the speed of C. For loop-heavy numerical code, performance can be comparable to, and sometimes better than, Cython, the other compiler commonly used for this purpose.
The problem is that numba can't intuit the type of lookup. If you put a print nb.typeof(lookup) in your method, you'll see that numba is treating it as an object, which is slow. Normally I would just define the type of lookup in a locals dict, but I was getting a strange error. Instead I just created a little wrapper, so that I could explicitly define the input and output types.
@nb.jit(nb.f8[:](nb.f8[:]))
def numba_cumsum(x):
    return np.cumsum(x)

@nb.autojit
def numba_resample2(qs, xs, rands):
    n = qs.shape[0]
    #lookup = np.cumsum(qs)
    lookup = numba_cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
Then my timings are:
print "Timing Numba Function:"
%timeit numba_resample(qs, xs, rands)
print "Timing Revised Numba Function:"
%timeit numba_resample2(qs, xs, rands)
Timing Numba Function:
100 loops, best of 3: 8.1 ms per loop
Timing Revised Numba Function:
100000 loops, best of 3: 15.3 µs per loop
You can go a little faster still if you use jit instead of autojit:
@nb.jit(nb.f8[:](nb.f8[:], nb.f8[:], nb.f8[:]))
For me that lowers it from 15.3 microseconds to 12.5 microseconds, but it's still impressive how well autojit does.
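A note for readers on newer Numba releases: autojit is no longer available there, and jit without a signature (or njit) compiles lazily on the first call instead. The following is a minimal sketch of the same loop under that assumption. The name numba_resample_njit is made up for illustration, and it assumes a Numba version recent enough to compile np.cumsum in nopython mode, in which case the wrapper above should not be needed. I have not re-run the timings with it.

```python
import numpy as np

try:
    from numba import njit
except ImportError:
    # Fallback so the sketch still runs (uncompiled) if Numba is missing.
    def njit(func):
        return func


@njit
def numba_resample_njit(qs, xs, rands):
    # Same algorithm as numba_resample; types are inferred on first call.
    n = qs.shape[0]
    lookup = np.cumsum(qs)
    results = np.empty(n)
    for j in range(n):
        for i in range(n):
            if rands[j] < lookup[i]:
                results[j] = xs[i]
                break
    return results
```

The first call includes compilation time, so call the function once before timing it.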
Faster numpy version (10x speedup compared to numpy_resample):
def numpy_faster(qs, xs, rands):
    lookup = np.cumsum(qs)
    mm = lookup[None,:]>rands[:,None]
    I = np.argmax(mm,1)
    return xs[I]
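Another option worth trying, offered as a sketch and not as a benchmarked result: since lookup is sorted, np.searchsorted with side="right" finds the first index whose cumulative weight exceeds each random number, which is the same condition the loops test. The function name searchsorted_resample is mine, and the final clamp is an assumption about how to treat a random number that lands at or above lookup[-1] because of rounding (the loop versions leave that entry unset, and argmax would return index 0).

```python
import numpy as np


def searchsorted_resample(qs, xs, rands):
    lookup = np.cumsum(qs)
    # First index i with lookup[i] > rands[j], found by binary search.
    idx = np.searchsorted(lookup, rands, side="right")
    # Clamp in case rounding leaves lookup[-1] slightly below a random number.
    idx = np.minimum(idx, len(xs) - 1)
    return xs[idx]
```

This avoids building the n-by-n boolean matrix that numpy_faster needs, so memory use grows linearly with n.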