Spark load model and continue training

I'm using Scala with Spark 2.0 to train a model with LinearRegression.

import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setMaxIter(num_iter)
  .setRegParam(reg)
  .setStandardization(true)

val model = lr.fit(data)

This works fine and I get good results. I saved the model and loaded it in another class to make some predictions:

import org.apache.spark.ml.regression.LinearRegressionModel

val model = LinearRegressionModel.load("models/LRModel")
val result = model.transform(data).select("prediction")
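For context, a fitted `LinearRegressionModel` is essentially a coefficient vector plus an intercept (exposed as `model.coefficients` and `model.intercept`), and `transform` adds a `prediction` column holding the dot product of those coefficients with each row's features, plus the intercept. A minimal plain-Scala sketch of that arithmetic, with no Spark involved (`LinearModelSketch` is a hypothetical stand-in, not a Spark class):

```scala
// Plain-Scala sketch of what a fitted linear regression model stores and computes.
// Not Spark code: LinearModelSketch is a hypothetical stand-in for LinearRegressionModel.
final case class LinearModelSketch(coefficients: Array[Double], intercept: Double) {
  // prediction = coefficients · features + intercept
  def predict(features: Array[Double]): Double = {
    require(features.length == coefficients.length, "feature count mismatch")
    coefficients.zip(features).map { case (w, x) => w * x }.sum + intercept
  }
}

object LinearModelSketchDemo {
  def main(args: Array[String]): Unit = {
    val model = LinearModelSketch(Array(2.0, -1.0), intercept = 0.5)
    println(model.predict(Array(3.0, 4.0))) // 2*3 - 1*4 + 0.5 = 2.5
  }
}
```

Nothing about the optimizer that produced the coefficients survives in this state, which is why a loaded model can predict but has nothing to "resume".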

Now I want to continue training the model on new data, so I saved both the model and the LinearRegression estimator and loaded them again to continue training.

Saving:

model.save("models/LRModel")
lr.save("models/LR")

Loading:

val lr = LinearRegression.load("models/LR")
val model = LinearRegressionModel.load("models/LRModel")

The problem is that the loaded model has no fit or train method to continue training. When I load the LinearRegression object, it seems to store only the algorithm's parameters, not the weights. I tested this by training on the same data for the same number of iterations: the result had exactly the same rootMeanSquaredError, and the model was definitely not converged at that point of learning. I also cannot load the model into a LinearRegression; that results in an error:

Exception in thread "main" java.lang.NoSuchMethodException: org.apache.spark.ml.regression.LinearRegressionModel.<init>(java.lang.String)
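The exception names a constructor taking a single `String`. That matches how, as far as I can tell, Spark 2.0 loads estimators through its default params reader: it reads the class name and uid from the saved metadata and reflectively calls that class's one-argument `(uid: String)` constructor. `LinearRegression` has such a constructor; `LinearRegressionModel` also needs its coefficients and intercept, so pointing `LinearRegression.load` at a saved model fails. A minimal plain-Scala reproduction of that failure mode, where `EstimatorLike`, `ModelLike` and `ReflectiveLoadSketch` are hypothetical stand-ins rather than Spark classes:

```scala
// Plain-Scala sketch of reflective loading via a (uid: String) constructor.
// EstimatorLike / ModelLike are hypothetical stand-ins, not Spark classes.
class EstimatorLike(val uid: String)
class ModelLike(val uid: String, val coefficients: Array[Double], val intercept: Double)

object ReflectiveLoadSketch {
  // Mimics a reader that only knows the saved class name and uid.
  def load(className: String, uid: String): AnyRef = {
    val cls = Class.forName(className)
    // Throws NoSuchMethodException if there is no public (String) constructor.
    cls.getConstructor(classOf[String]).newInstance(uid).asInstanceOf[AnyRef]
  }

  def main(args: Array[String]): Unit = {
    println(load(classOf[EstimatorLike].getName, "lr_1234")) // succeeds
    try {
      load(classOf[ModelLike].getName, "lrm_1234")
      println("unexpected: ModelLike was constructed")
    } catch {
      case e: NoSuchMethodException => println(s"Failed as expected: $e")
    }
  }
}
```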

So the question is, how do I get the LinearRegression object to use the saved LinearRegressionModel?

asked Sep 01 '16 13:09 by Silu


1 Answer

You can use a Pipeline to save and load machine learning models.

import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.ml.regression.LinearRegression

val lr = new LinearRegression()
  .setLabelCol("label")
  .setFeaturesCol("features")
  .setMaxIter(10)
  .setRegParam(1.0)
  .setElasticNetParam(1.0)

val pipeline = new Pipeline().setStages(Array(lr))

// fit returns a fitted PipelineModel; save that, not the unfitted Pipeline
val model = pipeline.fit(trainingData)

model.write.overwrite().save("hdfs://.../spark/mllib/models/linearRegression")

val sameModel = PipelineModel.load("hdfs://...")

// `assembler` is assumed to be a DataFrame with a "features" column (e.g. VectorAssembler output)
sameModel.transform(assembler).select("features", "label", "prediction").show()
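Reloading a `PipelineModel` restores the fitted coefficients for prediction, but it does not by itself continue training: fitting again calls `LinearRegression.fit`, and as far as I know Spark 2.0's `LinearRegression` exposes no parameter for initial coefficients, so every fit starts over. That follows from the Estimator/Model split: the estimator holds only hyperparameters, while the learned weights live in the model. The following plain-Scala sketch is not Spark's implementation; it swaps in ordinary full-batch gradient descent (Spark itself uses L-BFGS/OWL-QN or the normal equation) with hypothetical names `GdParams` and `WarmStartSketch` to show what continuing training would mean: feed previously learned weights back in as the starting point.

```scala
// Plain-Scala sketch of warm-starting batch gradient descent for least squares.
// Not Spark's optimizer. GdParams plays the estimator's role (hyperparameters only);
// the returned weight array plays the model's role (learned state).
final case class GdParams(maxIter: Int, stepSize: Double)

object WarmStartSketch {
  // Runs maxIter steps of full-batch gradient descent on mean squared error,
  // starting from `initial` rather than always from zeros.
  def fit(xs: Array[Array[Double]], ys: Array[Double],
          params: GdParams, initial: Array[Double]): Array[Double] = {
    val n = xs.length
    var w = initial.clone()
    for (_ <- 0 until params.maxIter) {
      val grad = new Array[Double](w.length)
      for (i <- 0 until n) {
        val err = dot(w, xs(i)) - ys(i)
        for (j <- w.indices) grad(j) += 2.0 * err * xs(i)(j) / n
      }
      w = w.indices.map(j => w(j) - params.stepSize * grad(j)).toArray
    }
    w
  }

  private def dot(a: Array[Double], b: Array[Double]): Double =
    a.indices.map(j => a(j) * b(j)).sum

  def main(args: Array[String]): Unit = {
    // Toy data from y = 1 + 2x; the constant 1.0 feature acts as the intercept.
    val xs = Array(Array(1.0, 0.0), Array(1.0, 1.0), Array(1.0, 2.0), Array(1.0, 3.0))
    val ys = Array(1.0, 3.0, 5.0, 7.0)
    val zeros = Array(0.0, 0.0)

    val oneRun = fit(xs, ys, GdParams(100, 0.05), zeros)
    val firstHalf = fit(xs, ys, GdParams(50, 0.05), zeros)   // "save" these weights...
    val resumed = fit(xs, ys, GdParams(50, 0.05), firstHalf) // ...and continue from them
    val restarted = fit(xs, ys, GdParams(50, 0.05), zeros)   // like reloading only the estimator

    println(oneRun.sameElements(resumed))      // resumed run matches the uninterrupted one
    println(restarted.sameElements(firstHalf)) // restarting just repeats the first 50 steps
  }
}
```

Because plain gradient descent carries no state besides the weights, 50 + 50 resumed iterations reproduce 100 uninterrupted ones exactly. A quasi-Newton method such as L-BFGS also keeps a curvature history, so restarting it from saved coefficients would only approximate an uninterrupted run.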
answered Oct 14 '22 05:10 by Jose Miguel