
How to save RandomForestClassifier Spark model in scala?

I built a random forest model using the following code:

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.feature.IndexToString

// labelIndexer and df are defined elsewhere (not shown)
val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("features")
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)
val training = labelIndexer.transform(df)
val model = rf.fit(training)

Now I want to save the model so that I can use it for prediction later, with code like the following:

val predictions: DataFrame = model.transform(testData)

I've looked into the Spark documentation here and didn't find any option to do that. Any ideas? It took me a few hours to build the model; if Spark crashes, I won't be able to get it back.

Yaeli778 asked May 18 '16 13:05


2 Answers

With Spark 1.6 you can save and reload tree-based models on HDFS using saveAsObjectFile(), for both Pipeline-based and plain models. Below is an example for a Pipeline-based model.

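import org.apache.spark.ml.PipelineModel

// Fit the pipeline to get a PipelineModel
val model = pipeline.fit(trainingData)

// Wrap the model in a single-partition RDD and save it as an object file
sc.parallelize(Seq(model), 1).saveAsObjectFile("hdfs://filepath")

// Reload the model using its class (same path as used for saving)
// You can get the class of an object with object.getClass()
val sameModel = sc.objectFile[PipelineModel]("hdfs://filepath").first()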
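For context, saveAsObjectFile() writes its elements with plain Java serialization, so the approach works only because the fitted model is Serializable. The sketch below shows that round trip without Spark. ToyModel, serialize and deserialize are hypothetical names made up for illustration, not part of any Spark API.

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical stand-in for a fitted model; case classes are Serializable by default.
final case class ToyModel(numTrees: Int, thresholds: Vector[Double])

// Write any serializable object to a byte array with Java serialization.
def serialize(value: AnyRef): Array[Byte] = {
  val buffer = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(buffer)
  try out.writeObject(value) finally out.close()
  buffer.toByteArray
}

// Read the object back and cast it to the class the caller expects.
def deserialize[T](bytes: Array[Byte]): T = {
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes))
  try in.readObject().asInstanceOf[T] finally in.close()
}

val original = ToyModel(numTrees = 20, thresholds = Vector(0.25, 0.5))
val restored = deserialize[ToyModel](serialize(original))
```

As with objectFile[PipelineModel], the reader has to name the class it expects, and that class must be on the classpath when loading. This suggests the saved file is best read back with the same Spark and Scala versions that wrote it.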
manohar amrutkar answered Sep 20 '22 12:09


To save and load a RandomForestClassifier model (tested with Spark 1.6.2 and Scala, using the ml package; Spark 2.0 has a direct save option for the model):

import org.apache.spark.ml.classification.RandomForestClassificationModel
import org.apache.spark.ml.classification.RandomForestClassifier

val classifier = new RandomForestClassifier().setImpurity("gini").setMaxDepth(3).setNumTrees(20).setFeatureSubsetStrategy("auto").setSeed(5043)
val model = classifier.fit(trainingData)

// Save the model: wrap it in a single-partition RDD and write it as an object file
sc.parallelize(Seq(model), 1).saveAsObjectFile(modelSavePath)

// Load the model back, specifying its class
val loadedModel = sc.objectFile[RandomForestClassificationModel](modelSavePath).first()
val predictions1 = loadedModel.transform(testData) // predictions1 is a DataFrame
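The direct save option mentioned for Spark 2.0 could look roughly like the sketch below, which assumes Spark 2.x or later on the classpath (for example, pasted into spark-shell). The toy data, the local master setting and the temporary modelPath are assumptions made for illustration only; in practice the path would point at HDFS or another shared store.

```scala
import java.nio.file.Files

import org.apache.spark.ml.classification.{RandomForestClassificationModel, RandomForestClassifier}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

// Local session for illustration; in spark-shell, getOrCreate() returns the existing session.
val spark = SparkSession.builder().master("local[*]").appName("rf-save-load").getOrCreate()
import spark.implicits._

// Tiny made-up training set using the default "label" and "features" columns.
val toyData = Seq(
  (0.0, Vectors.dense(0.0, 1.0)),
  (0.0, Vectors.dense(0.1, 0.9)),
  (1.0, Vectors.dense(1.0, 0.0)),
  (1.0, Vectors.dense(0.9, 0.2))
).toDF("label", "features")

val model = new RandomForestClassifier().setNumTrees(5).setSeed(5043).fit(toyData)

// Hypothetical location: a path inside a fresh temporary directory.
val modelPath = Files.createTempDirectory("rf-model").resolve("model").toString

// Persist the fitted model; overwrite() replaces anything already at the path.
model.write.overwrite().save(modelPath)

// Later, possibly in another session: load the model and predict with it.
val reloaded = RandomForestClassificationModel.load(modelPath)
val predictions = reloaded.transform(toyData)
predictions.select("label", "prediction").show()
```

Compared with saveAsObjectFile(), this route stores the model in Spark's own ML persistence format instead of relying on Java serialization of the model object.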
Mentya answered Sep 22 '22 12:09