
RandomForestClassifier was given input with invalid label column error in Apache Spark

I am trying to compute accuracy using 5-fold cross-validation with a Random Forest classifier in Scala, but I get the following error when running the code:

java.lang.IllegalArgumentException: RandomForestClassifier was given input with invalid label column label, without the number of classes specified. See StringIndexer.

The error is raised at this line: val cvModel = cv.fit(trainingData)

The code I used for cross-validating the data set with a random forest is as follows:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

val data = sc.textFile("exprogram/dataset.txt")
val parsedData = data.map { line =>
  val parts = line.split(',')
  LabeledPoint(parts(41).toDouble,
    Vectors.dense(parts(0).split(',').map(_.toDouble)))
}


val splits = parsedData.randomSplit(Array(0.6, 0.4), seed = 11L)
val training = splits(0)
val test = splits(1)

val trainingData = training.toDF()

val testData = test.toDF()

val nFolds: Int = 5
val NumTrees: Int = 5

val rf = new RandomForestClassifier()
      .setLabelCol("label")
      .setFeaturesCol("features")
      .setNumTrees(NumTrees)

val pipeline = new Pipeline()
      .setStages(Array(rf)) 

val paramGrid = new ParamGridBuilder()
          .build()

val evaluator = new  MulticlassClassificationEvaluator()
    .setLabelCol("label")
    .setPredictionCol("prediction")
    .setMetricName("precision") 

val cv = new CrossValidator()
   .setEstimator(pipeline)
   .setEvaluator(evaluator) 
   .setEstimatorParamMaps(paramGrid)
   .setNumFolds(nFolds)

val cvModel = cv.fit(trainingData)

val results = cvModel.transform(testData)
  .select("label","prediction").collect

val numCorrectPredictions = results.map(row =>
  if (row.getDouble(0) == row.getDouble(1)) 1 else 0).foldLeft(0)(_ + _)
val accuracy = 1.0D * numCorrectPredictions / results.size

println("Test set accuracy: %.3f".format(accuracy))

Can anyone please explain what is wrong with the above code?

spiky asked Apr 09 '16 13:04


1 Answer

RandomForestClassifier, like many other ML algorithms, requires specific metadata to be set on the label column, and requires label values to be integral values in [0, 1, 2, ..., #classes) represented as doubles. Typically this is handled by an upstream Transformer such as StringIndexer. Since you convert the labels manually, the metadata fields are not set and the classifier cannot confirm that these requirements are satisfied.

val df = Seq(
  (0.0, Vectors.dense(1, 0, 0, 0)),
  (1.0, Vectors.dense(0, 1, 0, 0)),
  (2.0, Vectors.dense(0, 0, 1, 0)),
  (2.0, Vectors.dense(0, 0, 0, 1))
).toDF("label", "features")

val rf = new RandomForestClassifier()
  .setFeaturesCol("features")
  .setNumTrees(5)

rf.setLabelCol("label").fit(df)
// java.lang.IllegalArgumentException: RandomForestClassifier was given input ...
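
To see why the classifier complains, it can help to inspect what the label column actually carries. The following is a minimal sketch, assuming the `df` defined above is in scope in a Spark shell; `labelField` is a name introduced here for illustration only.

```scala
import org.apache.spark.ml.attribute.Attribute

// Assumes `df` from the snippet above is in scope.
val labelField = df.schema("label")

// Raw column metadata as JSON; a manually built column is not expected
// to contain an "ml_attr" entry here.
println(labelField.metadata.json)

// Spark's own interpretation of the column as an ML attribute.
println(Attribute.fromStructField(labelField))
```

Running the same check on a column produced by StringIndexer, or on one with manually attached metadata, should show the nominal attribute information the classifier looks for.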

You can either re-encode the label column using StringIndexer:

import org.apache.spark.ml.feature.StringIndexer

val indexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")
  .fit(df)

rf.setLabelCol("label_idx").fit(indexer.transform(df))

or set the required metadata manually:

import org.apache.spark.ml.attribute.NominalAttribute

val meta = NominalAttribute
  .defaultAttr
  .withName("label")
  .withValues("0.0", "1.0", "2.0")
  .toMetadata

rf.setLabelCol("label_meta").fit(
  df.withColumn("label_meta", $"label".as("", meta))
)
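
For the cross-validation setup in the question, the StringIndexer option could also be added as a pipeline stage. This is only a sketch under a few assumptions: `trainingData` is the DataFrame from the question, `label_idx` and `labelIndexer` are names chosen here, and the metric name is copied from the question (newer Spark versions may expect "accuracy" instead).

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.RandomForestClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

// Assumes `trainingData` has a Double "label" column and a Vector "features" column.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("label_idx")

// The classifier now reads the indexed column, which carries the metadata.
val rf = new RandomForestClassifier()
  .setLabelCol("label_idx")
  .setFeaturesCol("features")
  .setNumTrees(5)

val pipeline = new Pipeline().setStages(Array(labelIndexer, rf))

// The evaluator must compare predictions against the indexed labels as well.
val evaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("label_idx")
  .setPredictionCol("prediction")
  .setMetricName("precision")

val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(new ParamGridBuilder().build())
  .setNumFolds(5)

val cvModel = cv.fit(trainingData)
```

Because the indexer is fitted inside each fold, a class that happens to be missing from a training fold may cause a failure when it shows up in the validation fold. Predictions are also expressed in the indexed label space, so any later comparison should use `label_idx` rather than the original `label` column.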

Note:

Label indices assigned by StringIndexer depend on label frequency, not on the label value:

indexer.labels
// Array[String] = Array(2.0, 0.0, 1.0)
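
The frequency-based ordering can be imitated in plain Scala to make the mapping explicit. `frequencyIndex` below is a hypothetical helper, not part of Spark, and the alphabetical tie-break is an assumption made here only to keep the result deterministic.

```scala
// Hypothetical helper (not a Spark API): assigns index 0.0 to the most
// frequent label, 1.0 to the next one, and so on.
def frequencyIndex(labels: Seq[String]): Map[String, Double] =
  labels
    .groupBy(identity)
    .map { case (label, occurrences) => (label, occurrences.size) }
    .toSeq
    // Most frequent first; ties are broken by label (an assumption of this sketch).
    .sortBy { case (label, count) => (-count, label) }
    .zipWithIndex
    .map { case ((label, _), idx) => (label, idx.toDouble) }
    .toMap

// Same label values as in the example DataFrame above.
val mapping = frequencyIndex(Seq("0.0", "1.0", "2.0", "2.0"))
```

Here "2.0" occurs twice, so it receives index 0.0 even though it is the largest label value, which mirrors the order shown by indexer.labels.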

PySpark:

In Python, metadata fields can be set directly on the schema:

from pyspark.sql.types import StructField, DoubleType

StructField(
    "label", DoubleType(), False,
    {"ml_attr": {
        "name": "label",
        "type": "nominal", 
        "vals": ["0.0", "1.0", "2.0"]
    }}
)
zero323 answered Oct 09 '22 08:10