I am using cython to compute a pairwise distance matrix using a custom metric as a faster alternative to scipy.spatial.distance.pdist.
My metric has the form
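```python
import numpy as np

def mymetric(u, v, w):
    return np.sum(w * (1 - np.abs(np.abs(u - v) / np.pi - 1))**2)
```
and the pairwise distance using scipy can be computed as
```python
import scipy as sp
import scipy.spatial.distance  # ensures sp.spatial.distance is loaded on older SciPy versions
x = sp.spatial.distance.pdist(r, metric=lambda u, v: mymetric(u, v, w))
```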
Here, `r` is an `m`-by-`n` matrix of `m` vectors of dimension `n`, and `w` is a "weight" factor of dimension `n`.
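For checking compiled versions against something simple, here is a minimal pure-Python sketch of the same metric and of the condensed ordering (all pairs `i < j`, row by row) that `pdist` returns. The names `my_metric` and `condensed_distances` are hypothetical, and it assumes plain lists instead of NumPy arrays so that it runs without dependencies.

```python
import math
from itertools import combinations

def my_metric(u, v, w):
    """sum_k w[k] * (1 - | |u[k] - v[k]| / pi - 1 |)**2, as in mymetric above."""
    return sum(wk * (1.0 - abs(abs(uk - vk) / math.pi - 1.0)) ** 2
               for uk, vk, wk in zip(u, v, w))

def condensed_distances(r, w):
    """Distances for all pairs i < j, in the same order as pdist's output."""
    return [my_metric(r[i], r[j], w) for i, j in combinations(range(len(r)), 2)]

r = [[0.0, 1.0], [math.pi, 0.5], [2.0, 2.0]]
w = [1.0, 0.5]
d = condensed_distances(r, w)
print(len(d))  # 3, i.e. m * (m - 1) // 2 for m = 3
```

`itertools.combinations(range(m), 2)` yields `(0, 1), (0, 2), ..., (1, 2), ...`, which is the same pair order the Cython loops below write into `ans`.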
Since `m` is rather large in my problem, the computation is really slow: for `m = 2000` and `n = 10` it takes approximately 20 s.
I implemented a simple function in Cython that computes the pairwise distance and immediately got very promising results: a speedup of over 500x.
```cython
import numpy as np
cimport numpy as np
import cython
from libc.math cimport fabs, M_PI

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size
    cdef np.ndarray[np.double_t, ndim=1] ans
    size = r.shape[0] * (r.shape[0] - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    c = -1
    for i in range(r.shape[0]):
        for j in range(i + 1, r.shape[0]):
            c += 1
            for k in range(r.shape[1]):
                ans[c] += w[k] * (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))**2.0
    return ans
```
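One detail worth double-checking is `size = r.shape[0] * (r.shape[0] - 1) / 2`. Under Python 3 semantics `/` is true division and yields a float, while `//` is floor division. How Cython treats `/` between C integers can depend on the `language_level` directive, so writing `//` states the intent unambiguously. A plain-Python illustration (the helper `condensed_size` is hypothetical):

```python
def condensed_size(m):
    """Number of pairs i < j among m rows, kept as an integer."""
    return m * (m - 1) // 2

m = 2000
print(m * (m - 1) / 2)    # true division gives a float: 1999000.0
print(condensed_size(m))  # floor division stays an int: 1999000
```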
I wanted to speed up the computation further using OpenMP; however, the following solution is roughly 3 times slower than the serial version.
```cython
import numpy as np
cimport numpy as np
import cython
from cython.parallel import prange, parallel
cimport openmp
from libc.math cimport fabs, M_PI

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.double_t a
    cdef np.ndarray[np.double_t, ndim=1] ans
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                for k in range(n):
                    ans[c] += w[k] * (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))**2.0
    return ans
```
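The parallel version replaces the running counter `c` with the closed form `c = i*(m-1) - i*(i+1)/2 + j - 1`, so that each row `i` can be processed independently inside `prange`. A quick pure-Python sanity check (the helpers `condensed_index` and `check_formula` are hypothetical) that the formula reproduces the serial counter, using `//` so the arithmetic stays integral:

```python
def condensed_index(i, j, m):
    """Position of pair (i, j), i < j, in the row-by-row condensed array."""
    return i * (m - 1) - i * (i + 1) // 2 + j - 1

def check_formula(m):
    """Compare the closed form against the counter used in the serial loop."""
    c = -1
    for i in range(m):
        for j in range(i + 1, m):
            c += 1  # serial counter, as in pairwise_distance
            if condensed_index(i, j, m) != c:
                return False
    return c == m * (m - 1) // 2 - 1  # last index written

print(all(check_formula(m) for m in range(2, 50)))  # True
```

Since `i * (i + 1)` is always even, the integer division is exact. Note that row `i` contains `m - 1 - i` pairs, so rows carry very different amounts of work; this is a reason to prefer a `dynamic` or `guided` schedule over a static split of `i`.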
I don't know why it is actually slower, but I tried to introduce the following changes. The speedup achieved through this is negligible: performance is even slightly worse, and the resulting distance array `ans` is computed correctly only at the beginning; the rest is just zeros.
```cython
import numpy as np
cimport numpy as np
import cython
from cython.parallel import prange, parallel
cimport openmp
from libc.math cimport fabs, M_PI
from libc.stdlib cimport malloc, free

@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp_2(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int k, l, c, m, n
    cdef Py_ssize_t i, j, d
    cdef size_t size
    cdef int *ci, *cj
    cdef np.ndarray[np.double_t, ndim=1, mode="c"] ans
    cdef np.ndarray[np.double_t, ndim=2, mode="c"] data
    cdef np.ndarray[np.double_t, ndim=1, mode="c"] weight
    data = np.ascontiguousarray(r, dtype=np.float64)
    weight = np.ascontiguousarray(w, dtype=np.float64)
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    # lookup tables: row (ci) and column (cj) of every pair in the condensed array
    cj = <int*> malloc(size * sizeof(int))
    ci = <int*> malloc(size * sizeof(int))
    c = -1
    for i in range(m):
        for j in range(i + 1, m):
            c += 1
            ci[c] = i
            cj[c] = j
    with nogil, parallel(num_threads=8):
        for d in prange(size, schedule='guided'):
            for k in range(n):
                ans[d] += weight[k] * (1.0 - fabs(fabs(data[ci[d], k] - data[cj[d], k]) / M_PI - 1.0))**2.0
    # release the buffers allocated with malloc above
    free(ci)
    free(cj)
    return ans
```
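`pairwise_distance_omp_2` flattens the double loop into a single `prange` over `d` by precomputing the row and column of every pair in the `ci`/`cj` tables. The same bookkeeping in plain Python (the helper `build_pair_tables` is hypothetical) shows that the tables enumerate exactly the pairs of the nested loop, in order, and agree with the closed-form index. In the Cython version these buffers come from `malloc`, so each one needs a matching `free` once the parallel loop has finished.

```python
def build_pair_tables(m):
    """Row and column index for each flat position d of the condensed array."""
    ci, cj = [], []
    for i in range(m):
        for j in range(i + 1, m):
            ci.append(i)
            cj.append(j)
    return ci, cj

m = 6
ci, cj = build_pair_tables(m)
size = m * (m - 1) // 2
pairs = list(zip(ci, cj))
print(len(pairs) == size, pairs[:3])  # True [(0, 1), (0, 2), (0, 3)]

# Each flat index d also matches the closed form used in pairwise_distance_omp.
print(all(d == i * (m - 1) - i * (i + 1) // 2 + j - 1
          for d, (i, j) in enumerate(pairs)))  # True
```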
For all functions, I am using the following `.pyxbld` file:
```python
def make_ext(modname, pyxfilename):
    from distutils.extension import Extension
    return Extension(name=modname,
                     sources=[pyxfilename],
                     extra_compile_args=['-O3', '-march=native', '-ffast-math', '-fopenmp'],
                     extra_link_args=['-fopenmp'],
                     )
```
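I have zero experience with Cython and know only the basics of C. I would appreciate any suggestions about what may be causing this unexpected behavior, or even about how to phrase my question better.

Following the suggestions in the answer, I also tried adding `@cython.cdivision(True)` and replacing `**2.0` with `tmp*tmp`, first in the serial version and then in the parallel one: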
```cython
@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_2(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    size = r.shape[0] * (r.shape[0] - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    c = -1
    for i in range(r.shape[0]):
        for j in range(i + 1, r.shape[0]):
            c += 1
            accumulator = 0
            for k in range(r.shape[1]):
                tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                accumulator += w[k] * (tmp*tmp)
            ans[c] = accumulator
    return ans
```
```cython
@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp_2d(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    ans[c] += w[k] * (tmp*tmp)
    return ans
```
When I try to apply the accumulator solution proposed in the answer, I get the following error:
```
Error compiling Cython file:
------------------------------------------------------------
...
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    accumulator += w[k] * (tmp*tmp)
                ans[c] = accumulator
                                   ^
------------------------------------------------------------

pdist.pyx:207:36: Cannot read reduction variable in loop body
```
Full code:
```cython
@cython.cdivision(True)
@cython.wraparound(False)
@cython.boundscheck(False)
def pairwise_distance_omp(np.ndarray[np.double_t, ndim=2] r, np.ndarray[np.double_t, ndim=1] w):
    cdef int i, j, k, c, size, m, n
    cdef np.ndarray[np.double_t, ndim=1] ans
    cdef np.double_t accumulator, tmp
    m = r.shape[0]
    n = r.shape[1]
    size = m * (m - 1) / 2
    ans = np.zeros(size, dtype=r.dtype)
    with nogil, parallel(num_threads=8):
        for i in prange(m, schedule='dynamic'):
            for j in range(i + 1, m):
                c = i * (m - 1) - i * (i + 1) / 2 + j - 1
                accumulator = 0
                for k in range(n):
                    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
                    accumulator += w[k] * (tmp*tmp)
                ans[c] = accumulator
    return ans
```
Cython allows you to release the GIL. That means you can do multi-threading in at least two ways:
- directly in Cython, using OpenMP with `prange`;
- using, e.g., joblib with a multi-threading backend (the parts of your code that will be parallelized are the parts that release the GIL).
I haven't timed this myself, so it's possible this won't help much. However:
If you run `cython -a` to get an annotated version of your initial attempt (`pairwise_distance_omp`), you'll find that the `ans[c] += ...` line is yellow, suggesting it has Python overhead. A look at the C code corresponding to that line suggests that it's checking for division by zero. One key part of it starts:
```c
if (unlikely(M_PI == 0)) {
```
You know this will never be true (and in any case you'd probably live with NaN values rather than an exception if it were). You can avoid this check by adding the following extra decorator to the function:
```cython
@cython.cdivision(True)
# other decorators
def pairwise_distance_omp(...):  # etc.
```
This cuts out quite a bit of C code, including bits that have to be run in a single thread. The flip-side is that most of that code should never be run, and the compiler should probably be able to work that out, so it isn't clear how much difference that will make.
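For context on what `@cython.cdivision(True)` switches off: according to the Cython documentation, without it Cython adds a zero-divisor check that raises `ZeroDivisionError` (the `M_PI == 0` test above) and makes `/` and `%` on C integers follow Python's rounding rules instead of C's. The rounding difference can be seen from plain Python by emulating C's truncating division; `c_div` and `c_mod` are hypothetical helpers, not Cython functions.

```python
import math

def c_div(a, b):
    """Integer division truncating toward zero, as in C."""
    q = abs(a) // abs(b)
    return q if (a >= 0) == (b > 0) else -q

def c_mod(a, b):
    """Remainder consistent with c_div: a == b * c_div(a, b) + c_mod(a, b)."""
    return a - b * c_div(a, b)

print(-7 // 2, c_div(-7, 2))  # -4 -3  (Python floors, C truncates)
print(-7 % 2, c_mod(-7, 2))   # 1 -1
print(math.fmod(-7.0, 2.0))   # -1.0, the C-style remainder for floats
# With non-negative operands, as in i * (i + 1) / 2, both conventions agree:
print(all(n // 2 == c_div(n, 2) for n in range(100)))  # True
```

So for the index arithmetic in this function, where all operands are non-negative, enabling `cdivision` should not change any result; it only removes the checks.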
Second suggestion:
```cython
# at the top
cdef np.double_t accumulator, tmp

# further down, inside the loop over j:
c = i * (m - 1) - i * (i + 1) / 2 + j - 1
accumulator = 0
for k in range(r.shape[1]):
    tmp = (1.0 - fabs(fabs(r[i, k] - r[j, k]) / M_PI - 1.0))
    accumulator = accumulator + w[k] * (tmp*tmp)
ans[c] = accumulator
```
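This hopefully has two advantages: 1) `tmp*tmp` should be quicker than raising a floating-point value to the power 2.0; 2) you avoid reading from the `ans` array, which might be a bit slow because the compiler always has to be careful that some other thread hasn't changed it (even though you know it shouldn't have). Note also that the update is written as `accumulator = accumulator + ...` rather than `accumulator += ...`: inside a `prange` loop, Cython treats a variable updated with an in-place operator as a reduction variable, and reduction variables cannot be read in the loop body, which is what produces the "Cannot read reduction variable in loop body" error.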
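The accumulator pattern for a single pair of rows, sketched in plain Python (the helpers `pair_distance` and `pairwise` are hypothetical, and lists stand in for the NumPy arrays): the sum is built in a local variable, `ans` is written exactly once per entry and never read, and the square is spelled `tmp * tmp`.

```python
import math

def pair_distance(u, v, w):
    """One entry of ans: accumulate locally, then return once."""
    accumulator = 0.0
    for k in range(len(w)):
        tmp = 1.0 - abs(abs(u[k] - v[k]) / math.pi - 1.0)
        accumulator = accumulator + w[k] * (tmp * tmp)  # tmp * tmp instead of tmp ** 2.0
    return accumulator

def pairwise(r, w):
    """Condensed distance array, same layout as the Cython functions."""
    m = len(r)
    ans = [0.0] * (m * (m - 1) // 2)
    c = -1
    for i in range(m):
        for j in range(i + 1, m):
            c += 1
            ans[c] = pair_distance(r[i], r[j], w)  # single write, no read of ans
    return ans

print(pairwise([[0.0, 1.0], [math.pi, 0.5], [0.0, 1.0]], [1.0, 0.5]))
```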