 

Distance between nodes and the centroid in a kmeans cluster?

Is there any way to extract the distance between the nodes and the centroid in a k-means cluster?

I have run k-means clustering on a text-embedding dataset, and I want to know which nodes are far away from the centroid in each cluster, so that I can check which of their features make the difference.

Thanks in advance!

asked Jan 17 '19 at 16:01 by Arav



2 Answers

KMeans.transform() returns an array with the distance of each sample to every cluster center, one column per cluster.

import numpy as np

from sklearn.datasets import make_blobs
from sklearn.cluster import KMeans

import matplotlib.pyplot as plt
plt.style.use('ggplot')
import seaborn as sns

# Generate some random clusters
X, y = make_blobs()
kmeans = KMeans(n_clusters=3).fit(X)

# plot the cluster centers and samples
# (x= and y= must be passed as keywords in recent seaborn versions)
sns.scatterplot(x=kmeans.cluster_centers_[:,0], y=kmeans.cluster_centers_[:,1], 
                marker='+', 
                color='black', 
                s=200);
sns.scatterplot(x=X[:,0], y=X[:,1], hue=y, 
                palette=sns.color_palette("Set1", n_colors=3));

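To see what transform() actually returns, here is a small sketch that checks its shape against the number of clusters and compares it with Euclidean distances computed directly from cluster_centers_. It assumes a fixed random_state and n_init=10 only so the run is repeatable; those values are not part of the answer above.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Small reproducible data set (random_state fixed only for repeatability)
X, _ = make_blobs(n_samples=100, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

# transform() gives one column per cluster center: shape (n_samples, n_clusters)
X_dist = kmeans.transform(X)
print(X_dist.shape)

# The same distances by hand, via broadcasting:
# X[:, None, :] is (100, 1, 2) and cluster_centers_[None, :, :] is (1, 3, 2)
manual = np.linalg.norm(X[:, None, :] - kmeans.cluster_centers_[None, :, :], axis=2)
print(np.allclose(X_dist, manual))
```

Each row therefore holds a sample's distance to all k centers, not just to the center of its own cluster.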

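Square the output of transform(X) and sum each row (axis=1); a larger sum means the sample lies further from the cluster centers overall. Note that this sum includes the distances to every center, not only to the sample's own centroid.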

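# squared distance of each sample to every cluster center, shape (n_samples, n_clusters)
X_dist = kmeans.transform(X)**2

# collect the per-sample sum of squared distances in a DataFrame
import pandas as pd
df = pd.DataFrame(X_dist.sum(axis=1).round(2), columns=['sqdist'])
df['label'] = y  # true blob labels from make_blobs

df.head()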
    sqdist  label
0   211.12  0
1   257.58  0
2   347.08  1
3   209.69  0
4   244.54  0

As a visual check, here is the same plot, this time with the point furthest from each cluster center highlighted:

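# for each k-means cluster, find the point furthest from its own center
labels = kmeans.labels_
# squared distance of each sample to the center of the cluster it was assigned to
own_dist = X_dist[np.arange(len(X)), labels]
max_indices = []
for label in np.unique(labels):
    X_label_indices = np.where(labels == label)[0]
    max_label_idx = X_label_indices[np.argmax(own_dist[X_label_indices])]
    max_indices.append(max_label_idx)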

# replot, but highlight the furthest point
sns.scatterplot(x=kmeans.cluster_centers_[:,0], y=kmeans.cluster_centers_[:,1], 
                marker='+', 
                color='black', 
                s=200);
sns.scatterplot(x=X[:,0], y=X[:,1], hue=y, 
                palette=sns.color_palette("Set1", n_colors=3));
# highlight the furthest point in black
sns.scatterplot(x=X[max_indices, 0], y=X[max_indices, 1], color='black');

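To go beyond the single furthest point per cluster, here is a sketch that ranks every member of each cluster by its distance to its own centroid and reports the top few, which is closer to "which nodes are far away" in the question. The toy data, random_state, n_init=10 and the top_k value are assumptions for illustration; with real text embeddings, X would be the embedding matrix.

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

# Reproducible toy data; swap in your own embedding matrix here
X, _ = make_blobs(n_samples=100, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)
labels = kmeans.labels_

# Distance of every sample to every center, shape (n_samples, n_clusters)
all_dist = kmeans.transform(X)
# For each row, keep only the column of the cluster the sample was assigned to
own_dist = all_dist[np.arange(len(X)), labels]

top_k = 3  # hypothetical: how many outliers to report per cluster
furthest = {}
for label in np.unique(labels):
    members = np.where(labels == label)[0]         # indices of samples in this cluster
    order = np.argsort(own_dist[members])[::-1]    # furthest first
    furthest[label] = members[order[:top_k]]

for label, idx in furthest.items():
    print(f"cluster {label}: samples {idx.tolist()}, distances {own_dist[idx].round(2).tolist()}")
```

Since k-means assigns each sample to its nearest center, own_dist should equal all_dist.min(axis=1); indexing by labels_ just makes the intent explicit.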

answered Oct 21 '22 at 17:10 by Kevin



If you are using Python and scikit-learn, the fitted KMeans estimator (https://scikit-learn.org/stable/modules/generated/sklearn.cluster.KMeans.html#sklearn.cluster.KMeans) gives you labels_ and cluster_centers_.

Next, choose a distance function that takes the vector of each node and its cluster center. Filter the samples by labels_ and compute, for each point in each cluster, its distance to that cluster's center.
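The steps above can be sketched as follows. The helper names euclidean, cosine and distances_to_centroids are hypothetical, as are the toy data and random_state. Cosine distance is a swapped-in option that is often used for text embeddings; KMeans itself minimizes squared Euclidean distance, so a cosine ranking can differ from the one k-means implies (and the cosine helper assumes no all-zero vectors).

```python
import numpy as np
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs


def euclidean(points, center):
    """Euclidean distance from each row of `points` to `center`."""
    return np.linalg.norm(points - center, axis=1)


def cosine(points, center):
    """Cosine distance (1 - cosine similarity); assumes no all-zero vectors."""
    sims = points @ center / (np.linalg.norm(points, axis=1) * np.linalg.norm(center))
    return 1.0 - sims


def distances_to_centroids(X, labels, centers, dist_fn=euclidean):
    """Distance of every sample to the center of its own cluster."""
    out = np.empty(len(X))
    for label in np.unique(labels):
        mask = labels == label                      # filter samples by labels_
        out[mask] = dist_fn(X[mask], centers[label])
    return out


X, _ = make_blobs(n_samples=100, centers=3, random_state=0)
kmeans = KMeans(n_clusters=3, n_init=10, random_state=0).fit(X)

d = distances_to_centroids(X, kmeans.labels_, kmeans.cluster_centers_)
d_cos = distances_to_centroids(X, kmeans.labels_, kmeans.cluster_centers_, dist_fn=cosine)
```

With the Euclidean helper, d should match the own-cluster column of kmeans.transform(X), so the manual route and transform() agree.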

answered Oct 21 '22 at 18:10 by avchauzov
