 

Pyspark: K means result with distance or deviation?

Tags:

pyspark

From https://spark.apache.org/docs/2.2.0/ml-clustering.html#k-means

I know that after kmModel.transform(df), the output dataframe has a prediction column stating which cluster each record/point belongs to.

However, I'd also like to know how far each record/point deviates from its centroid, so that I can tell which points in a cluster are typical and which may lie between clusters.

How can I do this? It does not seem to be implemented in the package by default.

Thanks!

asked Sep 05 '18 04:09 by cqcn1991




1 Answer

Let's assume we have the following sample data and k-means model:

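from pyspark.sql import SparkSession
from pyspark.ml.linalg import Vectors
from pyspark.ml.clustering import KMeans
import pyspark.sql.functions as F

# Reuse the active session if there is one (e.g. in the pyspark shell)
spark = SparkSession.builder.getOrCreate()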

data = [(Vectors.dense([0.0, 0.0]),), (Vectors.dense([1.0, 1.0]),),
        (Vectors.dense([9.0, 8.0]),), (Vectors.dense([8.0, 9.0]),),
        (Vectors.dense([10.0, 1.5]),), (Vectors.dense([11, 0.0]),) ]
df = spark.createDataFrame(data, ["features"])

n_centres = 2
kmeans = KMeans().setK(n_centres).setSeed(1)
kmModel = kmeans.fit(df)
df_pred = kmModel.transform(df)
df_pred.show()

+----------+----------+
|  features|prediction|
+----------+----------+
| [0.0,0.0]|         1|
| [1.0,1.0]|         1|
| [9.0,8.0]|         0|
| [8.0,9.0]|         0|
|[10.0,1.5]|         0|
|[11.0,0.0]|         0|
+----------+----------+
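To see what the prediction column means in terms of distance: k-means in Spark ML (Euclidean distance by default) labels each point with the index of its nearest center. The plain-Python sketch below, which does not use Spark at all, recomputes those labels from the two centers that kmModel.clusterCenters() returns for this data (they appear in the next table). squared_distance and assign are hypothetical helper names for illustration, not pyspark functions.

```python
# Plain-Python sketch (no Spark): assign each point to its nearest center,
# mirroring what the `prediction` column contains.

def squared_distance(a, b):
    """Squared Euclidean distance between two equal-length sequences."""
    return sum((x - y) ** 2 for x, y in zip(a, b))

def assign(point, centers):
    """Return the index of the center closest to `point`."""
    return min(range(len(centers)),
               key=lambda i: squared_distance(point, centers[i]))

points = [[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0], [10.0, 1.5], [11.0, 0.0]]
# Centers as returned by kmModel.clusterCenters() for this data (index 0, then 1)
centers = [[9.5, 4.625], [0.5, 0.5]]

predictions = [assign(p, centers) for p in points]
print(predictions)  # [1, 1, 0, 0, 0, 0]
```

The labels match the prediction column above, which is why the distance to the assigned center is the natural measure of how "typical" a point is.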

Now, let's add a column containing the coordinates of each point's cluster center:

from pyspark.sql.types import IntegerType

l_clusters = kmModel.clusterCenters()
# Convert the list of centers (numpy arrays) to a dict {cluster id: list of floats}
d_clusters = {int(i): [float(x) for x in center]
              for i, center in enumerate(l_clusters)}

# Create a dataframe mapping each cluster id to its center's coordinates
df_centers = spark.createDataFrame([(k, v) for k, v in d_clusters.items()],
                                   ['prediction', 'center'])

# Give the join key a consistent integer type, then attach each point's center
df_pred = df_pred.withColumn('prediction', F.col('prediction').cast(IntegerType()))
df_pred = df_pred.join(df_centers, on='prediction', how='left')
df_pred.show()


+----------+----------+------------+
|prediction|  features|      center|
+----------+----------+------------+
|         0| [8.0,9.0]|[9.5, 4.625]|
|         0|[10.0,1.5]|[9.5, 4.625]|
|         0| [9.0,8.0]|[9.5, 4.625]|
|         0|[11.0,0.0]|[9.5, 4.625]|
|         1| [1.0,1.0]|  [0.5, 0.5]|
|         1| [0.0,0.0]|  [0.5, 0.5]|
+----------+----------+------------+
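At convergence, each k-means center is the mean of the points assigned to its cluster. A quick plain-Python check (no Spark) reproduces the center column above from the sample points and their predictions; centroid is a hypothetical helper name used only for this illustration.

```python
# Plain-Python sketch (no Spark): recompute each cluster center as the
# coordinate-wise mean of the points assigned to it.

def centroid(points):
    """Coordinate-wise mean of a non-empty list of equal-length points."""
    n = len(points)
    return [sum(coords) / n for coords in zip(*points)]

points = [[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0], [10.0, 1.5], [11.0, 0.0]]
predictions = [1, 1, 0, 0, 0, 0]  # the `prediction` column from transform()

centers = {
    k: centroid([p for p, c in zip(points, predictions) if c == k])
    for k in sorted(set(predictions))
}
print(centers)  # {0: [9.5, 4.625], 1: [0.5, 0.5]}
```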

Finally, we can use a UDF to compute the squared Euclidean distance between the features column and the center coordinates (take the square root if you need the plain Euclidean distance):

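from pyspark.sql.types import DoubleType

# squared_distance returns a numpy float; convert it to a plain Python float for Spark
get_dist = F.udf(lambda features, center:
                 float(features.squared_distance(center)), DoubleType())
df_pred = df_pred.withColumn('dist', get_dist(F.col('features'), F.col('center')))
df_pred.show()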

+----------+----------+------------+---------+
|prediction|  features|      center|     dist|
+----------+----------+------------+---------+
|         0|[11.0,0.0]|[9.5, 4.625]|23.640625|
|         0| [9.0,8.0]|[9.5, 4.625]|11.640625|
|         0| [8.0,9.0]|[9.5, 4.625]|21.390625|
|         0|[10.0,1.5]|[9.5, 4.625]|10.015625|
|         1| [1.0,1.0]|  [0.5, 0.5]|      0.5|
|         1| [0.0,0.0]|  [0.5, 0.5]|      0.5|
+----------+----------+------------+---------+
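The dist column only says how far a point is from its own center. For the second part of the question (which points sit between clusters), one option is to compare each point's distance to its own center with its distance to the nearest other center: a ratio near 0 means a typical point, a ratio near 1 means the point is almost equidistant from two centers. This ratio heuristic is an illustrative assumption, not something Spark computes, and the sketch below is plain Python (no Spark) with hypothetical helper names euclidean and describe.

```python
import math

def euclidean(a, b):
    """Euclidean (not squared) distance between two equal-length sequences."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))

def describe(point, centers):
    """Return (own cluster, distance to own center, own / nearest-other ratio).

    Assumes at least two centers and that the point does not coincide
    with two centers at once (which would divide by zero).
    """
    dists = sorted((euclidean(point, c), i) for i, c in enumerate(centers))
    (d_own, own), (d_next, _) = dists[0], dists[1]
    return own, d_own, d_own / d_next

points = [[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0], [10.0, 1.5], [11.0, 0.0]]
centers = [[9.5, 4.625], [0.5, 0.5]]  # kmModel.clusterCenters() for this data

for p in points:
    own, d_own, ratio = describe(p, centers)
    print(p, own, round(d_own, 3), round(ratio, 3))
```

Sorting the rows of one cluster by d_own (or, in Spark, by the dist column) lists its most typical points first. Summing the dist column over all rows gives the within-cluster sum of squared distances that k-means minimizes (67.6875 for this data); in Spark 2.x, kmModel.computeCost(df) reports the same cost.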
answered Oct 16 '22 13:10 by plalanne