
Python, Pairwise 'distance', need a fast way to do it

For a side project in my PhD, I am modelling a system in Python. Performance-wise, my program hits a bottleneck in the following problem, which I'll illustrate with a minimal working example (MWE).

I deal with a large number of segments encoded by their 3D beginning and endpoints, so each segment is represented by 6 scalars.

I need to calculate a pairwise minimal intersegment distance. The analytical expression of the minimal distance between two segments is found in this source. To the MWE:

import numpy as np
N_segments = 1000
List_of_segments = np.random.rand(N_segments, 6)

Pairwise_minimal_distance_matrix = np.zeros( (N_segments,N_segments) )
for i in range(N_segments):
    for j in range(i+1,N_segments): 

        p0 = List_of_segments[i,0:3] #beginning point of segment i
        p1 = List_of_segments[i,3:6] #end point of segment i
        q0 = List_of_segments[j,0:3] #beginning point of segment j
        q1 = List_of_segments[j,3:6] #end point of segment j
        #for readability, some definitions
        a = np.dot( p1-p0, p1-p0)
        b = np.dot( p1-p0, q1-q0)
        c = np.dot( q1-q0, q1-q0)
        d = np.dot( p1-p0, p0-q0)
        e = np.dot( q1-q0, p0-q0)
        s = (b*e-c*d)/(a*c-b*b)
        t = (a*e-b*d)/(a*c-b*b)
        #the minimal distance between segment i and j
        Pairwise_minimal_distance_matrix[i,j] = np.sqrt(np.sum( (p0+(p1-p0)*s-(q0+(q1-q0)*t))**2)) #minimal distance

Now, I realize this is extremely inefficient, which is why I am here. I have looked extensively into how to avoid the loop, but I run into a problem. Apparently, this sort of calculation is best done with SciPy's cdist. However, the custom distance functions it can handle have to be binary functions. This is a problem in my case, because my vectors specifically have a length of 6 and have to be split into their first and last 3 components. I don't think I can translate the distance calculation into a binary function.

Any input is appreciated.

Mathusalem asked Feb 24 '15 10:02


2 Answers

You can use numpy's vectorization capabilities to speed up the calculation. My version computes all elements of the distance matrix at once and then sets the diagonal and the lower triangle to zero.

def pairwise_distance2(s):
    # we need this because we're gonna divide by zero
    old_settings = np.seterr(all="ignore")

    N = len(s) # number of segments, read from the input instead of the global N_segments

    # we repeat p0 and p1 along all columns
    p0 = np.repeat(s[:,0:3].reshape((N, 1, 3)), N, axis=1)
    p1 = np.repeat(s[:,3:6].reshape((N, 1, 3)), N, axis=1)
    # and q0, q1 along all rows
    q0 = np.repeat(s[:,0:3].reshape((1, N, 3)), N, axis=0)
    q1 = np.repeat(s[:,3:6].reshape((1, N, 3)), N, axis=0)

    # element-wise dot product over the last dimension,
    # while keeping the number of dimensions at 3
    # (so we can use them together with the p* and q*)
    a = np.sum((p1 - p0) * (p1 - p0), axis=-1).reshape((N, N, 1))
    b = np.sum((p1 - p0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    c = np.sum((q1 - q0) * (q1 - q0), axis=-1).reshape((N, N, 1))
    d = np.sum((p1 - p0) * (p0 - q0), axis=-1).reshape((N, N, 1))
    e = np.sum((q1 - q0) * (p0 - q0), axis=-1).reshape((N, N, 1))

    # same as above
    s = (b*e-c*d)/(a*c-b*b)
    t = (a*e-b*d)/(a*c-b*b)

    # almost same as above
    pairwise = np.sqrt(np.sum( (p0 + (p1 - p0) * s - ( q0 + (q1 - q0) * t))**2, axis=-1))

    # turn the error reporting back on
    np.seterr(**old_settings)

    # set everything at or below the diagonal to 0
    pairwise[np.tril_indices(N)] = 0.0

    return pairwise
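
As a side note, the np.repeat calls are not strictly needed: arrays with a size-1 axis broadcast against each other on their own. The following is only a sketch of that idea, not a tested replacement. The name pairwise_distance_broadcast is made up for illustration, and np.errstate is used in place of the np.seterr save/restore pair. Like the code above, it leaves s and t unclamped, so it measures the closest approach of the infinite lines through each pair of segments.

```python
import numpy as np


def pairwise_distance_broadcast(segments):
    """Hypothetical variant of pairwise_distance2 that relies on broadcasting."""
    n = len(segments)
    start = segments[:, 0:3]               # beginning points, shape (N, 3)
    direction = segments[:, 3:6] - start   # end minus beginning, shape (N, 3)

    # Size-1 axes let NumPy broadcast instead of storing N explicit copies.
    p0, u = start[:, None, :], direction[:, None, :]   # (N, 1, 3): segment i
    q0, v = start[None, :, :], direction[None, :, :]   # (1, N, 3): segment j
    w = p0 - q0                                         # (N, N, 3)

    # Dot products over the last axis; keepdims keeps a trailing axis of size 1.
    a = np.sum(u * u, axis=-1, keepdims=True)
    b = np.sum(u * v, axis=-1, keepdims=True)
    c = np.sum(v * v, axis=-1, keepdims=True)
    d = np.sum(u * w, axis=-1, keepdims=True)
    e = np.sum(v * w, axis=-1, keepdims=True)

    # The diagonal (i == j) divides 0 by 0; silence the warnings only here.
    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        dist = np.sqrt(np.sum((w + u * s - v * t) ** 2, axis=-1))

    # Same convention as above: zero everything at or below the diagonal.
    dist[np.tril_indices(n)] = 0.0
    return dist


segments = np.random.default_rng(0).random((5, 6))
distances = pairwise_distance_broadcast(segments)
print(distances.shape)  # (5, 5)
```

The (N, N, 3) intermediates are still created, so memory use remains quadratic in N. What goes away is the four explicitly repeated input arrays.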

Now let's take it for a spin. With your example, N = 1000, and your double loop wrapped in a function pairwise_distance, I get a timing of

%timeit pairwise_distance(List_of_segments)
1 loops, best of 3: 10.5 s per loop

%timeit pairwise_distance2(List_of_segments)
1 loops, best of 3: 398 ms per loop

And of course, the results are the same:

(pairwise_distance2(List_of_segments) == pairwise_distance(List_of_segments)).all()

returns True. I'm also pretty sure there's a matrix multiplication hidden somewhere in the algorithm, so there should be some potential for further speedup (and also cleanup).
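
To make the matrix-multiplication hunch a bit more concrete, here is one way it might look. This is a sketch under assumptions, not a benchmarked solution: the name pairwise_distance_matmul is hypothetical, the scalar products a to e are read off two matrix products, and the final norm is expanded algebraically as |w + s*u - t*v|^2 = w.w + s^2*a + t^2*c + 2*s*d - 2*t*e - 2*s*t*b, with u = p1 - p0, v = q1 - q0 and w = p0 - q0.

```python
import numpy as np


def pairwise_distance_matmul(segments):
    """Hypothetical helper: the scalar products come from matrix products."""
    n = len(segments)
    start = segments[:, 0:3]               # p0 of every segment, shape (N, 3)
    direction = segments[:, 3:6] - start   # p1 - p0, shape (N, 3)

    gram = direction @ direction.T         # gram[i, j]  = u_i . u_j
    cross = direction @ start.T            # cross[i, j] = u_i . p0_j
    norm2 = np.diag(gram)                  # u_i . u_i
    own = np.diag(cross)                   # u_i . p0_i

    a = norm2[:, None]                     # depends on i only
    b = gram
    c = norm2[None, :]                     # depends on j only
    d = own[:, None] - cross               # u_i . (p0_i - p0_j)
    e = cross.T - own[None, :]             # u_j . (p0_i - p0_j)

    # Squared distance between the beginning points of segments i and j.
    offset = start[:, None, :] - start[None, :, :]
    w2 = np.einsum("ijk,ijk->ij", offset, offset)

    with np.errstate(divide="ignore", invalid="ignore"):
        denom = a * c - b * b
        s = (b * e - c * d) / denom
        t = (a * e - b * d) / denom
        # Expanded squared norm; clip tiny negative values caused by rounding.
        dist2 = w2 + s * s * a + t * t * c + 2 * s * d - 2 * t * e - 2 * s * t * b
        dist = np.sqrt(np.maximum(dist2, 0.0))

    dist[np.tril_indices(n)] = 0.0         # keep only the upper triangle
    return dist


segments = np.random.default_rng(0).random((5, 6))
print(pairwise_distance_matmul(segments).shape)  # (5, 5)
```

Because the expanded form rounds differently from the direct one, comparing it with the other versions is safer with np.allclose than with ==. Whether it is actually faster would have to be measured.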

By the way: I've tried simply using numba first without success. Not sure why, though.

Carsten answered Oct 10 '22 16:10


This is more of a meta answer, at least for starters. Your problem might already be in "my program hits a bottleneck" and "I realize this is extremely inefficient".

Extremely inefficient? By what measure? Do you have a comparison? Is your code too slow to finish in a reasonable amount of time? What is a reasonable amount of time for you? Can you throw more computing power at the problem? Equally important -- do you run your code on proper infrastructure (numpy/scipy compiled with vendor compilers, possibly with OpenMP support)?

Then, if you have answers for all of the questions above and need to further optimize your code -- where is the bottleneck in your current code exactly? Did you profile it? Is the body of the loop possibly much more heavy-weight than the evaluation of the loop itself? If so, then "the loop" is not your bottleneck and you do not need to worry about the nested loop in the first place. Optimize the body first, possibly by coming up with unorthodox matrix representations of your data so that you can perform all these single calculations in one step -- by matrix multiplication, for instance. If your problem is not solvable by efficient linear algebra operations, you can start writing a C extension or use Cython or use PyPy (which just very recently got some basic numpy support!). There are endless possibilities for optimizing -- the questions really are: how close to a practical solution are you already, how much do you need to optimize, and how much of an effort are you willing to invest.
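
To illustrate the profiling step, here is a minimal sketch using cProfile from the standard library. The function names line_distance and pairwise_loop are made up for this example; the body is the calculation from the question, moved into its own function so that the profiler reports it separately from the loop.

```python
import cProfile
import io
import pstats

import numpy as np


def line_distance(p0, p1, q0, q1):
    """Loop body from the question, isolated so the profiler can see it."""
    u, v, w = p1 - p0, q1 - q0, p0 - q0
    a, b, c = np.dot(u, u), np.dot(u, v), np.dot(v, v)
    d, e = np.dot(u, w), np.dot(v, w)
    denom = a * c - b * b
    s = (b * e - c * d) / denom
    t = (a * e - b * d) / denom
    return np.sqrt(np.sum((w + u * s - v * t) ** 2))


def pairwise_loop(segments):
    """Doubly-nested loop over all pairs i < j."""
    n = len(segments)
    out = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            out[i, j] = line_distance(segments[i, :3], segments[i, 3:],
                                      segments[j, :3], segments[j, 3:])
    return out


# Small input so the profile finishes quickly; scale up for real measurements.
segments = np.random.default_rng(0).random((60, 6))

profiler = cProfile.Profile()
profiler.enable()
pairwise_loop(segments)
profiler.disable()

# Collect the report in a string, sorted by cumulative time per function.
stream = io.StringIO()
pstats.Stats(profiler, stream=stream).sort_stats("cumulative").print_stats(10)
report = stream.getvalue()
print(report)
```

In the report, the tottime column of pairwise_loop is the time spent in the loop itself, while the cumtime of line_distance covers the body and everything it calls. Comparing the two shows where the time actually goes.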

Disclaimer: I have done non-canonical pairwise-distance stuff with scipy/numpy for my PhD, too ;-). For one particular distance metric, I ended up coding the "pairwise" part in simple Python (i.e. I also used the doubly-nested loop), but spent some effort in getting the body as efficient as possible (with a combination of i) a cryptic matrix multiplication representation of my problem and ii) using bottleneck).

Dr. Jan-Philip Gehrcke answered Oct 10 '22 16:10