
Spark KMeans clustering: get the number of sample assigned to a cluster

I am using Spark MLlib for k-means clustering. I have a set of vectors from which I want to determine the most likely cluster center. My plan is to run k-means training on this set and select the cluster with the highest number of vectors assigned to it.

Therefore I need to know the number of vectors assigned to each cluster after training (i.e. after KMeans.run(...)). However, I cannot find a way to retrieve this information from the resulting KMeansModel. I could probably run predict on all training vectors and count the label that appears most often.

Is there another way to do this?

Thank you

Asked Oct 18 '22 by Khue Vu

1 Answer

You are right, this info is not provided by the model, and you have to run predict. Here is an example of doing so in a parallelized way (Spark v. 1.5.1):

 from pyspark.mllib.clustering import KMeans
 from numpy import array

 # 5 two-dimensional datapoints
 data = array([0.0,0.0, 1.0,1.0, 9.0,8.0, 8.0,9.0, 10.0, 9.0]).reshape(5, 2)
 data
 # array([[  0.,   0.],
 #       [  1.,   1.],
 #       [  9.,   8.],
 #       [  8.,   9.],
 #       [ 10.,   9.]])

 k = 2  # no. of clusters
 # sc is the SparkContext available in the PySpark shell
 model = KMeans.train(
                sc.parallelize(data), k, maxIterations=10, runs=30, initializationMode="random",
                seed=50, initializationSteps=5, epsilon=1e-4)

 # predict the cluster index for each datapoint, again on an RDD
 cluster_ind = model.predict(sc.parallelize(data))
 cluster_ind.collect()
 # [1, 1, 0, 0, 0]

cluster_ind is an RDD with the same cardinality as our initial data, and it shows which cluster each datapoint belongs to. So here we have two clusters, one with 3 datapoints (cluster 0) and one with 2 datapoints (cluster 1). Notice that we have run the prediction in a parallel fashion (i.e. on an RDD); collect() is used here only for demonstration purposes, and it is not needed in a 'real' situation.

Now, we can get the cluster sizes with

 cluster_sizes = cluster_ind.countByValue().items()
 cluster_sizes
 # [(0, 3), (1, 2)]
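
Note that countByValue() is an action, so the counts come back to the driver as a dictionary, which is fine for this toy example. If you wanted to keep the counts distributed (e.g. with a very large number of points), a roughly equivalent sketch using map and reduceByKey would be:

 from operator import add
 # same counts, but kept as an RDD of (cluster_index, count) pairs
 cluster_sizes_rdd = cluster_ind.map(lambda c: (c, 1)).reduceByKey(add)
 cluster_sizes_rdd.collect()
 # [(0, 3), (1, 2)]   (pair order is not guaranteed)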

From this, we can get the index and size of the biggest cluster as

 from operator import itemgetter
 max(cluster_sizes, key=itemgetter(1))
 # (0, 3)

i.e. our biggest cluster is cluster 0, with a size of 3 datapoints, which can be easily verified by inspection of cluster_ind.collect() above.
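
And since the original question is really after the most likely cluster center, the centroid of that biggest cluster can be looked up through the model's clusterCenters attribute; a minimal sketch, reusing the model and cluster_sizes from above:

 from operator import itemgetter
 best_cluster, best_size = max(cluster_sizes, key=itemgetter(1))
 # clusterCenters is indexed by the same labels that predict() returns
 model.clusterCenters[best_cluster]
 # centroid of cluster 0 -- roughly array([ 9.,  8.67]) for this toy data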

Answered Oct 21 '22 by desertnaut