
Is it possible to access estimator attributes in spark.ml pipelines?

I have a spark.ml pipeline in Spark 1.5.1 which consists of a series of transformers followed by a k-means estimator. I want to be able to access the KMeansModel.clusterCenters after fitting the pipeline, but can't figure out how. Is there a spark.ml equivalent of sklearn's pipeline.named_steps feature?

I found this answer which gives two options. The first works if I take the k-means model out of my pipeline and fit it separately, but that kinda defeats the purpose of a pipeline. The second option doesn't work: I get error: value getModel is not a member of org.apache.spark.ml.PipelineModel.

EDIT: Example pipeline:

import org.apache.spark.ml.feature.{HashingTF, IDF, Tokenizer}
import org.apache.spark.ml.clustering.{KMeans, KMeansModel}
import org.apache.spark.ml.Pipeline

// create example dataframe
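// rows are wrapped in Tuple1 because createDataFrame expects a Seq of Products
// (tuples or case classes); a parenthesized String is still just a String
val sentenceData = sqlContext.createDataFrame(Seq(
  Tuple1("Hi I heard about Spark"),
  Tuple1("I wish Java could use case classes"),
  Tuple1("K-means models are neat")
)).toDF("sentence")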

// initialize pipeline stages
val tokenizer = new Tokenizer().setInputCol("sentence").setOutputCol("words")
val hashingTF = new HashingTF().setInputCol("words").setOutputCol("features").setNumFeatures(20)
val kmeans = new KMeans()
val pipeline = new Pipeline().setStages(Array(tokenizer, hashingTF, kmeans))

// fit the pipeline
val fitKmeans = pipeline.fit(sentenceData)

So now fitKmeans is of type org.apache.spark.ml.PipelineModel. My question is, how do I access the cluster centers calculated by the k-means model contained within this pipeline? As noted above, when k-means is fitted on its own rather than inside a pipeline, the result is a KMeansModel and the centers are available directly via its clusterCenters member.

hilarious asked Oct 19 '15 17:10


1 Answer

Answering my own question: I finally stumbled on an example deep in the spark.ml docs that shows how to do this using the stages member of the PipelineModel class. So for the example I posted above, the k-means cluster centers can be accessed with:

val centers = fitKmeans.stages(2).asInstanceOf[KMeansModel].clusterCenters

where fitKmeans is a PipelineModel and 2 is the zero-based index of the k-means model in the array of pipeline stages (tokenizer is 0, hashingTF is 1, kmeans is 2).
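
If hard-coding the stage index feels brittle, one possible alternative is to look the model up by type instead. The following is only a sketch: findKMeans is a hypothetical helper name, not part of Spark, and it assumes the pipeline contains at most one k-means stage of interest (it returns the first one found).

```scala
import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.clustering.KMeansModel

// Hypothetical helper (not a Spark API): scan the fitted stages and
// return the first one that is a KMeansModel, or None if there is none.
def findKMeans(model: PipelineModel): Option[KMeansModel] =
  model.stages.collectFirst { case km: KMeansModel => km }
```

With the pipeline from the question, findKMeans(fitKmeans).map(_.clusterCenters) would then give the centers wrapped in an Option, without a cast and without depending on the position of the k-means stage.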

Reference: the last line of most of the examples on this page.

hilarious answered Oct 04 '22 10:10