 

Nearest neighbour search with a KD-tree

Given a list of N points [(x_1,y_1), (x_2,y_2), ... ], I am trying to find the nearest neighbour of each point based on distance. My dataset is too large for a brute-force approach, so a KD-tree seems best.

Rather than implement one from scratch, I see that sklearn.neighbors.KDTree can find nearest neighbours. Can it be used to find the nearest neighbour of each particle, i.e. return a list of length N?

asked Jan 06 '18 11:01 by RedPen


2 Answers

This question is very broad and missing details. It's unclear what you tried, what your data looks like, and what counts as a nearest neighbour (should a point count as its own neighbour?).

Assuming you are not interested in the point itself (its own match, at distance 0), you can query the two nearest neighbours and drop the first column. This is probably the easiest approach here.

Code:

 import numpy as np
 from sklearn.neighbors import KDTree
 np.random.seed(0)
 X = np.random.random((5, 2))  # 5 points in 2 dimensions
 tree = KDTree(X)
 nearest_dist, nearest_ind = tree.query(X, k=2)  # k=2: each point's first "neighbour" is itself
 print(X)
 print(nearest_dist[:, 1])    # drop self; relies on query's default sort_results=True
 print(nearest_ind[:, 1])     # drop self

Output

 [[ 0.5488135   0.71518937]
  [ 0.60276338  0.54488318]
  [ 0.4236548   0.64589411]
  [ 0.43758721  0.891773  ]
  [ 0.96366276  0.38344152]]
 [ 0.14306129  0.1786471   0.14306129  0.20869372  0.39536284]
 [2 0 0 0 1]
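For small inputs, one way to sanity-check the KD-tree result is to compare it with a brute-force O(N²) search over the full distance matrix. This is a sketch assuming NumPy and scikit-learn are installed; the helper name `brute_force_nn` is hypothetical. Comparing distances is the safer check, because with tied or duplicated points the two methods may legitimately pick different indices.

```python
import numpy as np
from sklearn.neighbors import KDTree


def brute_force_nn(X):
    """Hypothetical helper: nearest neighbour of each row of X, excluding itself.

    Builds the full N x N distance matrix, so it is only practical for small N.
    """
    diff = X[:, None, :] - X[None, :, :]          # shape (N, N, d)
    dist = np.sqrt((diff ** 2).sum(axis=-1))      # pairwise Euclidean distances
    np.fill_diagonal(dist, np.inf)                # ignore each point's distance to itself
    ind = dist.argmin(axis=1)                     # index of the closest other point
    return dist[np.arange(len(X)), ind], ind


np.random.seed(0)
X = np.random.random((5, 2))

# KD-tree approach: query k=2, then drop the self-match in column 0
tree = KDTree(X)
kd_dist, kd_ind = tree.query(X, k=2)
kd_dist, kd_ind = kd_dist[:, 1], kd_ind[:, 1]

bf_dist, bf_ind = brute_force_nn(X)
print(np.allclose(kd_dist, bf_dist))  # True
print(bf_ind)                         # [2 0 0 0 1]
```

The brute-force version needs O(N²) memory for the distance matrix, which is exactly why it does not scale to the dataset sizes the question mentions.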
answered Sep 24 '22 09:09 by sascha


You can use the query_radius() method of sklearn.neighbors.KDTree, which returns, for each query point, an array of the indices of all points within a given radius r (as opposed to the k nearest neighbours).

from sklearn.neighbors import KDTree

points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]

tree = KDTree(points, leaf_size=2)
all_nn_indices = tree.query_radius(points, r=1.5)  # per point: indices within distance 1.5 (order not guaranteed)
all_nns = [[points[idx] for idx in nn_indices] for nn_indices in all_nn_indices]
for nns in all_nns:
    print(nns)

Outputs:

[(1, 1), (2, 2)]
[(1, 1), (2, 2), (3, 3)]
[(2, 2), (3, 3), (4, 4)]
[(3, 3), (4, 4), (5, 5)]
[(4, 4), (5, 5)]
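By default query_radius() returns only indices, in no guaranteed order. If you also want the distances, sorted nearest first, you can pass `return_distance=True` together with `sort_results=True` (scikit-learn raises an error if `sort_results=True` is used without `return_distance=True`). A minimal sketch on the same five points:

```python
from sklearn.neighbors import KDTree

points = [(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]
tree = KDTree(points, leaf_size=2)

# With return_distance=True the method returns (indices, distances);
# ind[i] and dist[i] are arrays for query point i, sorted by increasing distance
ind, dist = tree.query_radius(points, r=1.5, return_distance=True, sort_results=True)

for i, (nn_ind, nn_dist) in enumerate(zip(ind, dist)):
    print(points[i], [(points[j], round(float(d), 3)) for j, d in zip(nn_ind, nn_dist)])
```

Since each point is at distance 0 from itself, the self-match always comes first in the sorted lists (as long as there are no duplicate points).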

Note that each point appears in its own list of neighbours, since its distance to itself is 0. To remove these self-matches, change the line computing all_nns to:

all_nns = [
    [points[idx] for idx in nn_indices if idx != i]
    for i, nn_indices in enumerate(all_nn_indices)
]

Resulting in:

[(2, 2)]
[(1, 1), (3, 3)]
[(2, 2), (4, 4)]
[(3, 3), (5, 5)]
[(4, 4)]
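If you only need indices (for example to index into a NumPy array of coordinates), the self-match can also be filtered out of each index array with a boolean mask. A small sketch, assuming the points are stored in a NumPy array named `P` (a name chosen here for illustration):

```python
import numpy as np
from sklearn.neighbors import KDTree

P = np.array([(1, 1), (2, 2), (3, 3), (4, 4), (5, 5)], dtype=float)
tree = KDTree(P, leaf_size=2)
all_nn_indices = tree.query_radius(P, r=1.5)

# For each query point i, keep every index except i itself
neighbours = [nn[nn != i] for i, nn in enumerate(all_nn_indices)]

for i, nn in enumerate(neighbours):
    print(P[i], "->", P[nn].tolist())
```

Because this filters by index rather than by coordinates, a duplicate point stored at a different index is still reported as a neighbour (at distance 0).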
answered Sep 20 '22 09:09 by scrpy