Given a list of points, how can I get their indices in a KDTree?
from scipy import spatial
import numpy as np
#some data
x, y = np.mgrid[0:3, 0:3]
data = list(zip(x.ravel(), y.ravel()))  # list() is needed in Python 3, where zip() returns a lazy iterator
points = [[0,1], [2,2]]
#KDTree
tree = spatial.cKDTree(data)
# indices of points in tree should be [1, 8]
I could do something like:
[tree.query_ball_point(i,r=0) for i in points]
>>> [[1], [8]]
Does it make sense to do it that way?
Use cKDTree.query(x, k, ...) to find the k nearest neighbours to a given set of points x:
distances, indices = tree.query(points, k=1)
print(repr(indices))
# array([1, 8])
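Note that query always returns the nearest neighbour, even when a query point is not actually in the dataset. Below is a minimal sketch, assuming you only want exact matches, that uses the returned distances to mark points with no exact match as -1. The helper name exact_indices and the -1 sentinel are illustrative choices, not part of SciPy:

```python
import numpy as np
from scipy import spatial

def exact_indices(tree, points, tol=0.0):
    """Hypothetical helper: index of each point in the tree, or -1 if absent."""
    # With k=1, query returns 1-D arrays: one distance and one index per point
    distances, indices = tree.query(points, k=1)
    # A point farther than `tol` from its nearest neighbour is not in the data
    return np.where(distances <= tol, indices, -1)

x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])
tree = spatial.cKDTree(data)
print(exact_indices(tree, [[0, 1], [2, 2], [0.5, 0.5]]))  # [ 1  8 -1]
```

Raising tol lets you accept near-matches, e.g. to absorb floating-point noise in real data.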
In a trivial case such as this, where both your dataset and your set of query points are small and each query point is identical to a single row of the dataset, it is likely to be faster to use simple boolean operations with broadcasting than to build and query a k-D tree:
data, points = np.array(data), np.array(points)
indices = (data[..., None] == points.T).all(1).argmax(0)
data[..., None] == points.T broadcasts out to an (nrows, ndims, npoints) array, which could quickly become expensive in terms of memory for larger datasets. In such cases you might get better performance out of a normal for loop or list comprehension, which only compares against one point at a time and so never allocates more than an (nrows, ndims) intermediate:
indices = [(data == p).all(1).argmax() for p in points]
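Both NumPy approaches share a pitfall: argmax returns 0 when no row matches, so a point that is missing from data is silently reported as index 0. A minimal sketch of the loop version that returns -1 instead (the helper name find_rows and the -1 sentinel are illustrative assumptions):

```python
import numpy as np

def find_rows(data, points):
    """Hypothetical helper: row index of each point in `data`, or -1 if absent."""
    data = np.asarray(data)
    result = []
    for p in np.asarray(points):
        # Boolean mask of rows whose every coordinate equals p
        matches = (data == p).all(axis=1)
        # argmax alone would return 0 for an all-False mask, so check first
        result.append(int(matches.argmax()) if matches.any() else -1)
    return result

x, y = np.mgrid[0:3, 0:3]
data = np.column_stack([x.ravel(), y.ravel()])
print(find_rows(data, [[0, 1], [2, 2], [5, 5]]))  # [1, 8, -1]
```

If a point occurs more than once in data, this returns the index of its first occurrence.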