How to cross-validate a RandomForest model?

I want to evaluate a random forest trained on some data. Is there any utility in Apache Spark to do this, or do I have to perform cross-validation manually?

asked Sep 24 '15 by ashishsjsu

People also ask

Is GridSearchCV cross-validation?

Yes, GridSearchCV performs cross-validation. If I understand the concept correctly, you want to keep part of your data set unseen by the model in order to test it. So you train your models on the training data set and test them on a separate test data set, while the cross-validation inside GridSearchCV runs only on the training data.
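Here is a minimal scikit-learn sketch of that idea (the dataset, parameter grid, and random_state values are illustrative assumptions, not from the question): GridSearchCV cross-validates within the training split only, and the held-out test split is scored once at the end.

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import GridSearchCV, train_test_split

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42)

param_grid = {"n_estimators": [10, 50], "max_depth": [None, 5]}
search = GridSearchCV(RandomForestClassifier(random_state=42),
                      param_grid, cv=5)   # 5-fold CV on the training split only
search.fit(X_train, y_train)

print(search.best_params_)            # best parameter combination found by CV
print(search.score(X_test, y_test))   # final check on the unseen test split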

Can we do cross-validation on random forest?

Yes, out-of-bag performance for a random forest is very similar to cross-validation. Essentially what you get is leave-one-out, with the surrogate random forests using fewer trees. So if done correctly, you get a slightly pessimistic bias.
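As a quick illustration (dataset and tree count are assumptions), a scikit-learn sketch comparing the out-of-bag estimate with an ordinary k-fold cross-validation score:

from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score

X, y = load_iris(return_X_y=True)

rf = RandomForestClassifier(n_estimators=200, oob_score=True, random_state=42)
rf.fit(X, y)
print("OOB accuracy:", rf.oob_score_)   # each sample scored by trees that never saw it

cv_scores = cross_val_score(
    RandomForestClassifier(n_estimators=200, random_state=42), X, y, cv=5)
print("5-fold CV accuracy:", cv_scores.mean())   # typically close to the OOB estimate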


1 Answer

Spark ML provides the CrossValidator class, which can be used to perform cross-validation and parameter search. Assuming your data is already preprocessed, you can add cross-validation as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

// [label: double, features: vector]
val trainingData: org.apache.spark.sql.DataFrame = ???

val nFolds: Int = ???
val numTrees: Int = ???
val metric: String = ???

val rf = new RandomForestClassifier()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setNumTrees(numTrees)

val pipeline = new Pipeline().setStages(Array(rf))

val paramGrid = new ParamGridBuilder().build() // No parameter search

val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label")
  .setPredictionCol("prediction")
  // "f1" (default), "weightedPrecision", "weightedRecall", "accuracy"
  .setMetricName(metric)

val cv = new CrossValidator()
  // ml.Pipeline with ml.classification.RandomForestClassifier
  .setEstimator(pipeline)
  // ml.evaluation.MulticlassClassificationEvaluator
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(nFolds)

val model = cv.fit(trainingData) // trainingData: DataFrame

Using PySpark:

from pyspark.ml import Pipeline
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

trainingData = ...  # DataFrame[label: double, features: vector]
numFolds = ...      # Integer

rf = RandomForestClassifier(labelCol="label", featuresCol="features")
evaluator = MulticlassClassificationEvaluator()  # + other params as in Scala

pipeline = Pipeline(stages=[rf])
paramGrid = (ParamGridBuilder()
    .addGrid(rf.numTrees, [3, 10])
    # .addGrid(...)  # Add other parameters
    .build())

crossval = CrossValidator(
    estimator=pipeline,
    estimatorParamMaps=paramGrid,
    evaluator=evaluator,
    numFolds=numFolds)

model = crossval.fit(trainingData)
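As a possible follow-up (a sketch reusing the variable names from the snippet above, not part of the original answer), the fitted CrossValidatorModel exposes the averaged metric for each point in the parameter grid and the best pipeline refit on the full training data:

for params, metric in zip(paramGrid, model.avgMetrics):
    print(params, metric)            # averaged cross-validation metric per ParamMap

best_pipeline = model.bestModel      # PipelineModel refit with the winning parameters
best_rf = best_pipeline.stages[0]    # the RandomForestClassificationModel inside it
predictions = best_rf.transform(trainingData)  # ideally an unseen test DataFrame instead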
answered Oct 30 '22 by zero323