 

How to extract best parameters from a CrossValidatorModel

I want to find the parameters from ParamGridBuilder that produce the best model in CrossValidator in Spark 1.4.x.

In the Pipeline example in the Spark documentation, different parameters (numFeatures, regParam) are added to the pipeline using ParamGridBuilder. Then the following line of code fits the best model:

val cvModel = crossval.fit(training.toDF)

Now I want to know which parameters (numFeatures, regParam) from ParamGridBuilder produce the best model.
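For context, ParamGridBuilder.build() returns an Array[ParamMap] with one entry for every combination of the values passed to addGrid, and CrossValidator evaluates each entry. The plain-Scala sketch below models that Cartesian product without Spark; the candidate values are hypothetical stand-ins, and plain Maps take the place of ParamMap.

```scala
object GridSketch {
  // Hypothetical candidate values, standing in for the addGrid(...) calls.
  val numFeaturesValues = Seq(10, 100, 1000)
  val regParamValues = Seq(0.1, 0.01)

  // Every (numFeatures, regParam) pair, like the ParamMaps ParamGridBuilder builds.
  val grid: Seq[Map[String, Any]] =
    for {
      nf <- numFeaturesValues
      rp <- regParamValues
    } yield Map("numFeatures" -> nf, "regParam" -> rp)

  def main(args: Array[String]): Unit =
    grid.foreach(println)
}
```

With three numFeatures values and two regParam values this gives six candidates. On a fitted CrossValidatorModel, getEstimatorParamMaps and avgMetrics hold one entry per candidate, aligned by index, which is what the first answer below relies on.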

I have already tried the following commands without success:

cvModel.bestModel.extractParamMap().toString()
cvModel.params.toList.mkString("(", ",", ")")
cvModel.estimatorParamMaps.toString()
cvModel.explainParams()
cvModel.getEstimatorParamMaps.mkString("(", ",", ")")
cvModel.toString()

Any help?

Thanks in advance,

asked Jul 31 '15 15:07 by Mohammad


4 Answers

One method to get a proper ParamMap object is to use CrossValidatorModel.avgMetrics: Array[Double] to find the argmax ParamMap:

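import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tuning.CrossValidatorModel

// Outside the REPL, define this implicit class inside an object.
implicit class BestParamMapCrossValidatorModel(cvModel: CrossValidatorModel) {
  // Pair each candidate ParamMap with its average metric and keep the highest
  // (assumes the evaluator's metric is larger-is-better).
  def bestEstimatorParamMap: ParamMap = {
    cvModel.getEstimatorParamMaps
      .zip(cvModel.avgMetrics)
      .maxBy(_._2)
      ._1
  }
}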

When run on the CrossValidatorModel trained in the Pipeline example you cited, this gives:

scala> println(cvModel.bestEstimatorParamMap)
{
    hashingTF_2b0b8ccaeeec-numFeatures: 100,
    logreg_950a13184247-regParam: 0.1
}
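The maxBy above assumes a larger metric is better, which fits the BinaryClassificationEvaluator used in the Pipeline example (areaUnderROC by default). For a metric where smaller is better, such as RMSE from RegressionEvaluator in newer Spark releases, maxBy would pick the worst candidate; newer releases expose the direction as Evaluator.isLargerBetter. A plain-Scala sketch of the zip-and-select step with the direction made explicit follows; the helper name bestParams and the sample metrics are hypothetical.

```scala
object BestParamsSketch {
  // Pick the candidate whose average metric is best. `largerIsBetter` plays the
  // role of Evaluator.isLargerBetter; the two sequences are aligned by index,
  // like getEstimatorParamMaps and avgMetrics.
  def bestParams[P](paramMaps: Seq[P], avgMetrics: Seq[Double], largerIsBetter: Boolean): P = {
    require(paramMaps.nonEmpty && paramMaps.size == avgMetrics.size, "need one metric per candidate")
    val scored = paramMaps.zip(avgMetrics)
    val (best, _) = if (largerIsBetter) scored.maxBy(_._2) else scored.minBy(_._2)
    best
  }

  def main(args: Array[String]): Unit = {
    // Hypothetical candidates and metrics, for illustration only.
    val grid = Seq(
      Map("numFeatures" -> 10, "regParam" -> 0.1),
      Map("numFeatures" -> 100, "regParam" -> 0.1),
      Map("numFeatures" -> 100, "regParam" -> 0.01)
    )
    val auc = Seq(0.71, 0.93, 0.88)  // larger is better
    val rmse = Seq(0.40, 0.31, 0.25) // smaller is better
    println(bestParams(grid, auc, largerIsBetter = true))
    println(bestParams(grid, rmse, largerIsBetter = false))
  }
}
```

Taking the direction as a parameter keeps one code path for both kinds of metric instead of duplicating the selection logic.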
answered Oct 06 '22 02:10 by Adam Vogel


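import org.apache.spark.ml.PipelineModel
import org.apache.spark.ml.classification.LogisticRegressionModel
import org.apache.spark.ml.feature.HashingTF

// bestModel is typed as Model[_]; with a Pipeline estimator it is a PipelineModel
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
val stages = bestPipelineModel.stages

// Stage order follows the pipeline: tokenizer, hashingTF, lr
val hashingStage = stages(1).asInstanceOf[HashingTF]
println("numFeatures = " + hashingStage.getNumFeatures)

val lrStage = stages(2).asInstanceOf[LogisticRegressionModel]
println("regParam = " + lrStage.getRegParam)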

source
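Hardcoding stages(1) and stages(2) ties the code to the stage order of the Pipeline example (tokenizer, hashingTF, lr). An alternative is to select a stage by its type with collectFirst, e.g. bestPipelineModel.stages.collectFirst { case h: HashingTF => h.getNumFeatures }, which returns an Option instead of throwing when the stage is missing. The plain-Scala sketch below models that lookup; TokenizerStage, HashingStage and LrStage are hypothetical stand-ins for Spark's stage classes.

```scala
object StageLookupSketch {
  // Hypothetical stand-ins for the fitted stages of a PipelineModel.
  sealed trait Stage
  final case class TokenizerStage(outputCol: String) extends Stage
  final case class HashingStage(numFeatures: Int) extends Stage
  final case class LrStage(regParam: Double, maxIter: Int) extends Stage

  // Find stages by type rather than by position; None if the stage is absent.
  def numFeatures(stages: Seq[Stage]): Option[Int] =
    stages.collectFirst { case h: HashingStage => h.numFeatures }

  def regParam(stages: Seq[Stage]): Option[Double] =
    stages.collectFirst { case lr: LrStage => lr.regParam }

  def main(args: Array[String]): Unit = {
    val stages = Seq(TokenizerStage("words"), HashingStage(100), LrStage(0.1, 10))
    println(s"numFeatures = ${numFeatures(stages).getOrElse("n/a")}")
    println(s"regParam = ${regParam(stages).getOrElse("n/a")}")
  }
}
```

The lookup keeps working if stages are reordered or inserted, at the cost of returning an Option the caller has to unwrap.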

answered Oct 06 '22 01:10 by macfeliga


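To print everything in the ParamMap, you actually don't have to call parent (this assumes the estimator given to CrossValidator is a single model such as LogisticRegression; if it is a Pipeline, bestModel is a PipelineModel whose own ParamMap does not include the stages' parameters):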

cvModel.bestModel.extractParamMap()

To answer the OP's question, here is how to get a single best parameter, for example regParam:

cvModel.bestModel.extractParamMap().apply(cvModel.bestModel.getParam("regParam"))
answered Oct 06 '22 02:10 by Algorithman


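This is how you get the chosen parameters when the estimator passed to CrossValidator is a LogisticRegression. Because bestModel is typed as Model[_], it has to be cast first:

import org.apache.spark.ml.classification.LogisticRegressionModel
val bestLr = cvModel.bestModel.asInstanceOf[LogisticRegressionModel] // fails if the estimator was a Pipeline
println(bestLr.getMaxIter)
println(bestLr.getRegParam)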
answered Oct 06 '22 02:10 by Mazen Aly