
Speeding up distance between all possible pairs in an array

I have an array of x,y,z coordinates of several (~10^10) points (only 5 shown here)

a= [[ 34.45  14.13   2.17]
    [ 32.38  24.43  23.12]
    [ 33.19   3.28  39.02]
    [ 36.34  27.17  31.61]
    [ 37.81  29.17  29.94]]

I want to make a new array containing only those points that are at least some distance d away from all other points in the list. I wrote this code using while loops:

 import numpy as np
 from scipy.spatial import distance

 d = 0.1  # or some distance
 i = 0
 selected_points = []
 while i < len(a):
     interdist = []
     j = i + 1
     while j < len(a):
         interdist.append(distance.euclidean(a[i], a[j]))
         j += 1

     if all(dis >= d for dis in interdist):
         selected_points.append(a[i])
     i += 1

 selected_points = np.array(selected_points)

This works, but it takes a really long time to run. I read somewhere that while loops are very slow.

I was wondering if anyone has any suggestions on how to speed up this calculation.

EDIT: While my objective of finding the particles that are at least some distance away from all the others stays the same, I just realized there is a serious flaw in my code. Say I have 3 particles. On the first iteration of i, my code calculates the distances 1->2 and 1->3. Suppose 1->2 is less than the threshold distance d, so the code throws away particle 1. On the next iteration of i, it only computes 2->3; suppose that is greater than d, so it keeps particle 2. But this is wrong! Particle 2 should also have been discarded along with particle 1. The solution by @svohara is the correct one!
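For a small array that fits in memory, a vectorized version that correctly discards *both* members of a too-close pair can be sketched with scipy's `pdist`/`squareform` (the threshold `d = 5.0` here is just chosen so the 5-point example above actually drops a pair; it is not the value from my code):

```python
import numpy as np
from scipy.spatial.distance import pdist, squareform

a = np.array([[34.45, 14.13,  2.17],
              [32.38, 24.43, 23.12],
              [33.19,  3.28, 39.02],
              [36.34, 27.17, 31.61],
              [37.81, 29.17, 29.94]])

d = 5.0  # illustrative threshold so this example discards a pair

# Full symmetric pairwise distance matrix; the diagonal (self-distance 0)
# is masked out so it doesn't disqualify every point.
dm = squareform(pdist(a))
np.fill_diagonal(dm, np.inf)

# Keep only points whose nearest other point is at least d away.
# Unlike the loop above, both members of a too-close pair are dropped.
keep = dm.min(axis=1) >= d
selected_points = a[keep]
```

Here the last two points are about 2.99 apart, so both are discarded and `selected_points` contains the first three rows. This is O(n^2) in memory, so it only works for small n, but it shows the symmetric discard the loop gets wrong.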

asked Feb 26 '16 by HuShu
1 Answer

For big data sets and low-dimensional points (such as your 3-dimensional data), sometimes there is a big benefit to using a spatial indexing method. One popular choice for low-dimensional data is the k-d tree.

The strategy is to index the data set. Then query the index using the same data set, to return the 2-nearest neighbors for each point. The first nearest neighbor is always the point itself (with dist=0), so we really want to know how far away the next closest point is (2nd nearest neighbor). For those points where the 2-NN is > threshold, you have the result.

from scipy.spatial import cKDTree as KDTree
import numpy as np

#a is the big data as numpy array N rows by 3 cols
a = np.random.randn(10**8, 3).astype('float32')

# This will create the index, prepare to wait...
# NOTE: took 7 minutes on my mac laptop with 10^8 rand 3-d numbers
#  there are some parameters that could be tweaked for faster indexing,
#  and there are implementations (not in scipy) that can construct
#  the kd-tree using parallel computing strategies (GPUs, e.g.)
k = KDTree(a)

#ask for the 2-nearest neighbors by querying the index with the
# same points
(dists, idxs) = k.query(a, 2)
# (dists, idxs) = k.query(a, 2, n_jobs=4)  # to use more CPUs on query...

#Note: 9 minutes for query on my laptop, 2 minutes with n_jobs=6
# So less than 10 minutes total for 10^8 points.

# If the second NN is > thresh distance, then there is no other point
# in the data set closer.
thresh_d = 0.1   #some threshold, equiv to 'd' in O.P.'s code
d_slice = dists[:, 1]  #distances to second NN for each point
res = np.flatnonzero( d_slice >= thresh_d )
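As a sanity check on a small sample, the k-d tree result can be compared against a brute-force pairwise computation (a sketch; the 500-point sample size and the seed are arbitrary choices):

```python
import numpy as np
from scipy.spatial import cKDTree as KDTree
from scipy.spatial.distance import pdist, squareform

rng = np.random.default_rng(0)
small = rng.standard_normal((500, 3))
thresh_d = 0.1

# k-d tree route: distance to the 2nd nearest neighbor of each point
dists, _ = KDTree(small).query(small, 2)
res_tree = np.flatnonzero(dists[:, 1] >= thresh_d)

# brute-force route: full pairwise matrix, self-distances masked out
dm = squareform(pdist(small))
np.fill_diagonal(dm, np.inf)
res_brute = np.flatnonzero(dm.min(axis=1) >= thresh_d)

# both routes should select exactly the same point indices
assert np.array_equal(res_tree, res_brute)
```

The brute-force version needs O(n^2) memory, which is why it only serves as a check here; for the full 10^8-point data set the k-d tree is the practical option.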
answered Sep 22 '22 by svohara