Usually I'm able to match Numba's performance when using Cython. However, in this example I have failed to do so: Numba is about 4 times faster than my Cython version.
Here is the Cython version:
%%cython -c=-march=native -c=-O3
cimport numpy as np
import numpy as np
cimport cython

@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output
And here is the Numba version:
import numpy as np
import numba as nb

@nb.njit
def nb_where(df):
    n = len(df)
    output = np.empty(n, dtype=np.float64)
    for i in range(n):
        if df[i]>0.5:
            output[i] = 2.0*df[i]
        else:
            output[i] = df[i]
    return output
When tested, the Cython version is on par with NumPy's np.where, but is clearly inferior to Numba:
#Python3.6 + Cython 0.28.3 + gcc-7.2
import numpy as np
np.random.seed(0)
n = 10000000
data = np.random.random(n)
assert (cy_where(data)==nb_where(data)).all()
assert (np.where(data>0.5,2*data, data)==nb_where(data)).all()
%timeit cy_where(data) # 179ms
%timeit nb_where(data) # 49ms (!!)
%timeit np.where(data>0.5,2*data, data) # 278 ms
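The loop body in all three versions is a data-dependent select, and with uniformly random data the condition data > 0.5 holds for roughly half of the elements in no predictable pattern, so a compiled branch is likely to be mispredicted often. A common workaround (an assumption here, not something measured in the timings above) is to express the select arithmetically, which gives a compiler the opportunity to emit branch-free or SIMD code. The sketch below only checks, in plain NumPy, that such a branch-free form produces bitwise-identical results; `branchless_where` is a hypothetical helper name.

```python
import numpy as np


def branchless_where(data):
    """Hypothetical helper: same result as np.where(data > 0.5, 2*data, data).

    (data > 0.5) is a boolean array; adding it to 1.0 yields a float factor
    of 2.0 where the condition holds and 1.0 elsewhere, so no per-element
    branch is needed.
    """
    return data * (1.0 + (data > 0.5))


np.random.seed(0)
data = np.random.random(10_000)
reference = np.where(data > 0.5, 2 * data, data)
# Multiplying by exactly 2.0 or 1.0 is exact, so results match bit for bit.
assert (branchless_where(data) == reference).all()
```

In the Cython loop, the analogous body would be output[i] = df[i] * (1.0 + (df[i] > 0.5)); whether gcc then vectorizes it has not been verified here.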
What is the reason for Numba's performance and how can it be matched when using Cython?
As suggested by @max9111, I eliminated the stride by writing through a contiguous memoryview, but this doesn't improve performance much:
@cython.boundscheck(False)
@cython.wraparound(False)
def cy_where_cont(double[::1] df):
    cdef int i
    cdef int n = len(df)
    cdef np.ndarray[dtype=double] output = np.empty(n, dtype=np.float64)
    cdef double[::1] view = output  # view as contiguous!
    for i in range(n):
        if df[i]>0.5:
            view[i] = 2.0*df[i]
        else:
            view[i] = df[i]
    return output
%timeit cy_where_cont(data) # 165 ms
This seems to be driven entirely by optimizations that LLVM (which Numba uses as its compiler backend) is able to make. If I compile the Cython example with clang, the performance of the two versions is identical. For what it's worth, MSVC on Windows shows a performance discrepancy with Numba similar to gcc's.
$ CC=clang ipython
<... setup code>
In [7]: %timeit cy_where(data)
   ...: %timeit nb_where(data)
30.8 ms ± 309 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
30.2 ms ± 498 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
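Since the result depends on which C compiler builds the extension, it can help to check what a build will actually use. A minimal sketch, assuming a POSIX setup where setuptools/distutils takes the compiler from the CC environment variable and otherwise falls back to the value recorded in sysconfig when Python itself was built (MSVC builds on Windows do not follow this rule); `effective_c_compiler` is a hypothetical helper name:

```python
import os
import sysconfig


def effective_c_compiler():
    """Hypothetical helper: the C compiler a POSIX extension build would most likely use.

    Assumption: the CC environment variable, if set, overrides the compiler
    recorded in sysconfig when this interpreter was built. On Windows/MSVC,
    sysconfig usually has no "CC" entry, so this may return None.
    """
    return os.environ.get("CC") or sysconfig.get_config_var("CC")


if __name__ == "__main__":
    print("C compiler for extension builds:", effective_c_compiler())
```

Setting CC before the extension is compiled (as with CC=clang ipython above) is what switches the build to clang.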