When I use scikit-learn's implementation of k-means, I usually just call the fit() method, and that is enough to get the cluster centers and the labels. The predict() method is used to calculate the labels, and even a fit_predict() method is available for convenience. But if I can get the labels using only fit(), what is the purpose of the predict() method?
predict(), as @EdChum suggested, can be used on unseen data. This (and even more so the transform() method) is useful when k-means is used for feature extraction in semi-supervised learning: you cluster a large set of samples, then use the nearest centroid or the distances to the centroids as features for a subsequent supervised learning problem. When you later use that supervised model for prediction, it receives samples that k-means never saw during fitting, so their cluster features have to come from predict() or transform().