
'KD tree' with custom distance metric

I want to use a KD tree with a custom distance metric (it is the best option for my project; other k-NN algorithms aren't optimal). I checked answers to similar questions here, and this should work, but it doesn't.

distance_matrix is symmetric, as it should be by definition:

array([[ 1.,  0.,  5.,  5.,  0.,  3.,  2.],
   [ 0.,  1.,  0.,  0.,  0.,  0.,  0.],
   [ 5.,  0.,  1.,  5.,  0.,  2.,  3.],
   [ 5.,  0.,  5.,  1.,  0.,  4.,  4.],
   [ 0.,  0.,  0.,  0.,  1.,  0.,  0.],
   [ 3.,  0.,  2.,  4.,  0.,  1.,  0.],
   [ 2.,  0.,  3.,  4.,  0.,  0.,  1.]])
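
The symmetry claim can be checked directly. The following is a minimal sketch, assuming NumPy is imported as np and the table above is stored in distance_matrix. It also looks at the diagonal, since a formal metric requires d(x, x) = 0.

```python
import numpy as np

# Lookup table copied from the question.
distance_matrix = np.array([
    [1., 0., 5., 5., 0., 3., 2.],
    [0., 1., 0., 0., 0., 0., 0.],
    [5., 0., 1., 5., 0., 2., 3.],
    [5., 0., 5., 1., 0., 4., 4.],
    [0., 0., 0., 0., 1., 0., 0.],
    [3., 0., 2., 4., 0., 1., 0.],
    [2., 0., 3., 4., 0., 0., 1.],
])

# Symmetric means the table equals its own transpose.
is_symmetric = bool(np.array_equal(distance_matrix, distance_matrix.T))

# A formal metric needs a zero diagonal; this table has ones there.
has_zero_diagonal = bool(np.all(np.diag(distance_matrix) == 0))

print(is_symmetric, has_zero_diagonal)
```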

I know my function is not formally a metric, but the documentation (under "User-defined distance") says the function has to be a formal metric only when I'm using a ball tree. Here is my code:

from sklearn.neighbors import DistanceMetric
def dist(x, y):
    dist = 0
    for elt_x, elt_y in zip(x, y):
        dist += distance_matrix[elt_x, elt_y]
    return dist
X = np.array([[1,0], [1,2], [1,3]])
tree = KDtree(X, metric=dist)

I get this error:

NameError
Traceback (most recent call last)   
<ipython-input-27-b5fac7810091> in <module>()
  7     return dist
  8 X = np.array([[1,0], [1,2], [1,3]])
----> 9 tree = KDtree(X, metric=dist)
NameError: name 'KDtree' is not defined

I also tried:

from sklearn.neighbors import KDTree
def dist(x, y):
    dist = 0
    for elt_x, elt_y in zip(x, y):
        dist += distance_matrix[elt_x, elt_y]
    return dist
X = np.array([[1,0], [1,2], [1,3]])
tree = KDTree(X, metric=lambda a,b: dist(a,b))

I get this error:

ValueError
Traceback (most recent call last)   
<ipython-input-27-b5fac7810091> in <module>()
  7     return dist
  8 X = np.array([[1,0], [1,2], [1,3]])
----> 9 tree = KDtree(X, metric=dist)
ValueError: metric PyFuncDistance is not valid for KDTree

I also tried:

from sklearn.neighbors import NearestNeighbors
nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree', metric=dist_metric)

I get the following error:

ValueError                                Traceback (most recent call last)
<ipython-input-32-c78d02cacb5a> in <module>()
      1 from sklearn.neighbors import NearestNeighbors
----> 2 nbrs = NearestNeighbors(n_neighbors=1, algorithm='kd_tree',     metric=dist_metric)

/usr/local/lib/python3.5/dist-packages/sklearn/neighbors/unsupervised.py    in __init__(self, n_neighbors, radius, algorithm, leaf_size, metric, p, metric_params, n_jobs, **kwargs)
    121                           algorithm=algorithm,
    122                           leaf_size=leaf_size, metric=metric, p=p,
--> 123                           metric_params=metric_params,     n_jobs=n_jobs, **kwargs)

/usr/local/lib/python3.5/dist-packages/sklearn/neighbors/base.py in     _init_params(self, n_neighbors, radius, algorithm, leaf_size, metric, p, metric_params, n_jobs)
    138                 raise ValueError(
    139                     "kd_tree algorithm does not support callable     metric '%s'"
--> 140                     % metric)
     141         elif metric not in VALID_METRICS[alg_check]:
    142             raise ValueError("Metric '%s' not valid for algorithm     '%s'"

ValueError: kd_tree algorithm does not support callable metric '<function     dist_metric at 0x7f58c2b3fd08>'

I tried all the other algorithms (auto, brute, ...), but I get the same error.

I have to use a distance matrix for the elements of the vectors, because each element is a code for a characteristic, and 5 can be closer to 1 than 3 is. What I need is the top 3 neighbors (sorted from closest to furthest).

spartan asked Dec 31 '17 12:12



1 Answer

Scikit-learn's KDTree does not support custom distance metrics. The BallTree does support custom distance metrics, but be careful: it is up to the user to make certain the provided metric is actually a valid metric: if it is not, the algorithm will happily return results of a query, but the results will be incorrect.

Also, you should be aware that using a custom Python function as a metric is generally too slow to be useful, because of the overhead of Python callbacks within the traversal of the tree.
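
For a small dataset, one way to avoid both the tree restriction and the per-pair Python callback is a plain brute-force search with NumPy indexing. This is a swapped-in technique, not a KD tree or ball tree, and only a sketch under stated assumptions: the helper names pairwise_lookup_distances and top_k_neighbors are hypothetical, the vectors hold integer codes that index the question's distance_matrix, and a smaller summed value means "closer", as the question's dist function implies.

```python
import numpy as np

# Lookup table copied from the question.
distance_matrix = np.array([
    [1., 0., 5., 5., 0., 3., 2.],
    [0., 1., 0., 0., 0., 0., 0.],
    [5., 0., 1., 5., 0., 2., 3.],
    [5., 0., 5., 1., 0., 4., 4.],
    [0., 0., 0., 0., 1., 0., 0.],
    [3., 0., 2., 4., 0., 1., 0.],
    [2., 0., 3., 4., 0., 0., 1.],
])


def pairwise_lookup_distances(queries, points, table):
    """Hypothetical helper: for every (query, point) pair, sum table[q_k, p_k]
    over the coordinates k, the same quantity as the question's dist()."""
    queries = np.asarray(queries, dtype=int)
    points = np.asarray(points, dtype=int)
    # Shapes (n_q, 1, d) and (1, n_p, d) broadcast to (n_q, n_p, d).
    return table[queries[:, None, :], points[None, :, :]].sum(axis=2)


def top_k_neighbors(queries, points, table, k=3):
    """Hypothetical helper: indices and distances of the k closest points,
    sorted from closest to furthest (ties keep their original order)."""
    dists = pairwise_lookup_distances(queries, points, table)
    order = np.argsort(dists, axis=1, kind="stable")[:, :k]
    return order, np.take_along_axis(dists, order, axis=1)


X = np.array([[1, 0], [1, 2], [1, 3]])
indices, distances = top_k_neighbors(X[:1], X, distance_matrix, k=3)
print(indices, distances)
```

This computes every query-to-point distance, so memory and time grow with the product of the number of queries and the number of points; it is only practical when that product stays modest.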

jakevdp answered Sep 24 '22 13:09
