
Scikit-learn - user-defined weights function for KNeighborsClassifier

I have a KNeighborsClassifier which classifies data based on 4 attributes. I'd like to weight those 4 attributes manually but always run into "operands could not be broadcast together with shapes (1,5) (4)".

There is very little documentation on weights: "[callable]: a user-defined function which accepts an array of distances, and returns an array of the same shape containing the weights." (from the scikit-learn documentation)

This is what I have for now:

    for v in result:
        params = [v['a_one'], v['a_two'], v['a_three'], v['a_four']]
        self.training_data['data'].append(params)
        self.training_data['target'].append(v['answer'])

    def get_weights(array_weights):
        return [1,1,2,1]

    classifier = neighbors.KNeighborsClassifier(weights=get_weights)
L-R asked Oct 13 '25 08:10



1 Answer

Explanation of the sklearn weights callable

    import numpy as np
    import pandas as pd
    from sklearn.neighbors import KNeighborsClassifier

Create sample data for model training

    df = pd.DataFrame({'feature1': [1, 3, 3, 4, 5], 'response': [1, 1, 1, 2, 2]})

    y = df.response
    # [1,1,1,2,2]

    X_train = df[['feature1']]
    # [1,3,3,4,5]

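Define a custom weights function (it prints the array of distances it receives)

    def my_distance(weights):
        # Despite the parameter name, sklearn passes in the neighbor distances
        print(weights)
        # Returned unchanged, so each neighbor's weight equals its distance
        return weights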

Define model passing in my_distance as a callable to weights

    knn = KNeighborsClassifier(n_neighbors=3, weights=my_distance)

    knn.fit(X_train, y)

    knn.predict([[1]])
    # printed by my_distance: [[ 0.  2.  2.]]
    # returned by predict:    array([1])

Explanation: the callable receives the distances from the query point 1 to its 3 closest neighbors (n_neighbors=3), as an array of shape (n_queries, n_neighbors)

The three closest neighbors to 1 in X_train:

1, 3, 3 

The distances:

[[ 0.  2.  2.]]

1 - 1 = 0 
3 - 1 = 2
3 - 1 = 2
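
These distances can be checked against what the classifier itself reports. The following is a minimal sketch, assuming plain NumPy arrays in place of the DataFrame; inverse_distance and seen_shapes are hypothetical names introduced for illustration. It records the shape of the array the callable receives and returns weights of that same shape.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Same toy data as above, as plain arrays (an assumption for brevity)
X_train = np.array([[1], [3], [3], [4], [5]])
y = np.array([1, 1, 1, 2, 2])

seen_shapes = []  # hypothetical helper: shapes passed to the callable


def inverse_distance(distances):
    """Return one weight per neighbor, same shape as the input."""
    seen_shapes.append(distances.shape)
    # The small constant avoids division by zero for exact matches
    return 1.0 / (distances + 1e-9)


knn = KNeighborsClassifier(n_neighbors=3, weights=inverse_distance)
knn.fit(X_train, y)

# kneighbors reports the distances that predict hands to the callable
dist, idx = knn.kneighbors([[1]])
pred = knn.predict([[1]])

print(dist)         # distances to the 3 nearest neighbors of the query
print(seen_shapes)  # shape(s) the callable received during predict
```

Because the returned array has the same shape as the input, (n_queries, n_neighbors), no broadcasting error occurs.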

The predicted class:

array([1])
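
As for the error in the question: with the default n_neighbors=5 and a single query point, the callable receives a distances array of shape (1, 5), so a returned list of 4 values cannot be broadcast against it. The weights callable weights neighbors, not attributes. One common way to weight attributes instead is to multiply each column by a factor before fitting and to scale query points the same way. The sketch below assumes the default Euclidean metric and assumes that [1, 1, 2, 1] is meant as per-attribute scale factors; the toy data and the name feature_scale are made up for illustration.

```python
import numpy as np
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical toy data with 4 attributes per sample
X = np.array([[1.5, 0.0, 0.0, 0.0],   # class 0, differs in attribute 1
              [0.0, 0.0, 1.0, 0.0]])  # class 1, differs in attribute 3
y = np.array([0, 1])
query = np.array([[0.0, 0.0, 0.0, 0.0]])

# Assumed per-attribute scale factors: attribute 3 counts double
feature_scale = np.array([1.0, 1.0, 2.0, 1.0])

# Unscaled: the class-1 point is nearer (distance 1.0 vs 1.5)
plain = KNeighborsClassifier(n_neighbors=1).fit(X, y)
plain_pred = plain.predict(query)

# Scaled: the class-1 point moves to distance 2.0, so class 0 is nearer
scaled = KNeighborsClassifier(n_neighbors=1).fit(X * feature_scale, y)
scaled_pred = scaled.predict(query * feature_scale)

print(plain_pred, scaled_pred)
```

Scaling a column by a factor s multiplies that attribute's difference by s inside the Euclidean distance, which is why the nearest neighbor changes here.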
Alex Sherman answered Oct 16 '25 06:10



