
python point indices in KDTree

Given a list of points, how can I get their indices in a KDTree?

from scipy import spatial
import numpy as np

#some data
x, y = np.mgrid[0:3, 0:3]
data = list(zip(x.ravel(), y.ravel()))  # zip is lazy in Python 3, so materialise it

points = [[0,1], [2,2]]

#KDTree
tree = spatial.cKDTree(data)

# indices of points in tree should be [1,8]

I could do something like:

[tree.query_ball_point(i,r=0) for i in points]

>>> [[1], [8]]

Does it make sense to do it that way?

a.smiet asked Oct 31 '22 15:10


1 Answer

Use cKDTree.query(x, k, ...) to find the k nearest neighbours to a given set of points x:

distances, indices = tree.query(points, k=1)
print(repr(indices))
# array([1, 8])
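
Note that query returns the nearest neighbour even when a query point is not in the tree at all, so the returned distances are worth checking. A minimal sketch, assuming an exact match is identified by a distance of zero (the extra point [5, 5] and the names `exact` and `found` are added here purely for illustration):

```python
import numpy as np
from scipy import spatial

x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])
tree = spatial.cKDTree(data)

# [5, 5] is a hypothetical query point that is not in the tree
points = np.array([[0, 1], [2, 2], [5, 5]])

distances, indices = tree.query(points, k=1)

# a distance of zero means the query point coincides with a tree point
exact = distances == 0
found = indices[exact]
```

Here `indices` still contains an entry for [5, 5] (its nearest neighbour), and only the `exact` mask reveals that it is not a true match.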

In a trivial case such as this, where both the dataset and the set of query points are small and each query point is identical to exactly one row of the dataset, it would be faster to use a simple boolean comparison with broadcasting than to build and query a k-D tree:

data, points = np.array(data), np.array(points)
indices = (data[..., None] == points.T).all(1).argmax(0)

data[..., None] == points.T broadcasts out to an (nrows, ndims, npoints) array, which could quickly become expensive in terms of memory for larger datasets. In such cases you might get better performance out of a normal for loop or list comprehension:

indices = [(data == p).all(1).argmax() for p in points]
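
One caveat applies to both boolean approaches: argmax returns 0 when no row matches, which is indistinguishable from a genuine match at row 0. A hedged sketch that marks missing points with -1 instead (the sentinel value, the extra point [5, 5], and the helper name `row_index` are assumptions made for illustration):

```python
import numpy as np

x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])
points = np.array([[0, 1], [2, 2], [5, 5]])  # [5, 5] is not in data

def row_index(data, p):
    """Return the index of the first row equal to p, or -1 if there is none."""
    match = (data == p).all(axis=1)  # one boolean per row of data
    return int(match.argmax()) if match.any() else -1

indices = [row_index(data, p) for p in points]

# the fully broadcast comparison has shape (nrows, ndims, npoints)
mask = data[..., None] == points.T
```

The per-point loop only ever holds an (nrows, ndims) boolean array, whereas `mask` grows with the number of query points.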

ali_m answered Nov 09 '22 16:11