Why does Spark ML NaiveBayes output labels that are different from the training data?

I use the NaiveBayes classifier in Apache Spark ML (version 1.5.1) to predict some text categories. However, the classifier outputs labels that are different from the labels in my training set. Am I doing it wrong?

Here is a small example that can be pasted into e.g. Zeppelin notebook:

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.Row

// Prepare training documents from a list of (id, text, label) tuples.
val training = sqlContext.createDataFrame(Seq(
  (0L, "X totally sucks :-(", 100.0),
  (1L, "Today was kind of meh", 200.0),
  (2L, "I'm so happy :-)", 300.0)
)).toDF("id", "text", "label")

// Configure an ML pipeline, which consists of three stages: tokenizer, hashingTF, and nb.
val tokenizer = new Tokenizer()
  .setInputCol("text")
  .setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")
val nb = new NaiveBayes()

val pipeline = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, nb))

// Fit the pipeline to training documents.
val model = pipeline.fit(training)

// Prepare test documents, which are unlabeled (id, text) tuples.
val test = sqlContext.createDataFrame(Seq(
  (4L, "roller coasters are fun :-)"),
  (5L, "i burned my bacon :-("),
  (6L, "the movie is kind of meh")
)).toDF("id", "text")

// Make predictions on test documents.
model.transform(test)
  .select("id", "text", "prediction")
  .collect()
  .foreach { case Row(id: Long, text: String, prediction: Double) =>
    println(s"($id, $text) --> prediction=$prediction")
  }

The output from the small program:

(4, roller coasters are fun :-)) --> prediction=2.0
(5, i burned my bacon :-() --> prediction=0.0
(6, the movie is kind of meh) --> prediction=1.0

The set of predicted labels {0.0, 1.0, 2.0} is disjoint from my training set labels {100.0, 200.0, 300.0}.

Question: How can I map these predicted labels back to my original training set labels?

Bonus question: why do the training set labels have to be doubles, when any other type would work just as well as a label? Seems unnecessary.

Pimin Konstantin Kefaloukos asked Sep 27 '22 05:09


1 Answer

However, the classifier outputs labels that are different from the labels in my training set. Am I doing it wrong?

Kind of. As far as I can tell, you're hitting the issue described in SPARK-9137. Generally speaking, all classifiers in ML expect 0-based labels (0.0, 1.0, 2.0, ...), but there is no validation step in ml.NaiveBayes. Under the hood, the data is passed to mllib.NaiveBayes, which doesn't have this limitation, so the training process runs without complaint.

When the model is converted back to ml, the prediction function simply assumes that the labels were correct and returns the predicted label using Vector.argmax, i.e. the index of the highest-scoring class rather than the original label value, hence the results you get. You can fix the labels using, for example, StringIndexer.
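
As a rough sketch of the StringIndexer approach (not a definitive implementation), the labels can be indexed before training and the predictions mapped back afterwards. It assumes `sqlContext` is in scope as in the question, and that `IndexToString` with `setLabels` is available in your Spark version, which is worth checking for 1.5.1; if it is not, indexing into `labelIndexer.labels` by hand does the same job. The column names `indexedLabel` and `predictedLabel` are made up for illustration.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.NaiveBayes
import org.apache.spark.ml.feature.{HashingTF, IndexToString, StringIndexer, Tokenizer}
import org.apache.spark.sql.Row

// Assumption: `sqlContext` already exists (Zeppelin or spark-shell session).
val training = sqlContext.createDataFrame(Seq(
  (0L, "X totally sucks :-(", 100.0),
  (1L, "Today was kind of meh", 200.0),
  (2L, "I'm so happy :-)", 300.0)
)).toDF("id", "text", "label")

val test = sqlContext.createDataFrame(Seq(
  (4L, "roller coasters are fun :-)"),
  (5L, "i burned my bacon :-("),
  (6L, "the movie is kind of meh")
)).toDF("id", "text")

// Fit the indexer up front: it maps each distinct label to an index 0.0, 1.0, 2.0, ...
// and keeps the original values (as strings) in `labels`.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(training)
val indexedTraining = labelIndexer.transform(training)

val tokenizer = new Tokenizer().setInputCol("text").setOutputCol("words")
val hashingTF = new HashingTF()
  .setNumFeatures(1000)
  .setInputCol(tokenizer.getOutputCol)
  .setOutputCol("features")

// Train on the 0-based indexed labels instead of the raw ones.
val nb = new NaiveBayes().setLabelCol("indexedLabel").setFeaturesCol("features")

// Translate predicted indices back to the original label values.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

val model = new Pipeline()
  .setStages(Array(tokenizer, hashingTF, nb, labelConverter))
  .fit(indexedTraining)

val predictions = model.transform(test)
  .select("id", "text", "predictedLabel")
  .collect()

predictions.foreach { case Row(id: Long, text: String, label: String) =>
  println(s"($id, $text) --> predictedLabel=$label")
}
```

Note that `predictedLabel` is a string column, because StringIndexer casts numeric input to strings; call `.toDouble` on the values if you need the original doubles back.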

why do the training set labels have to be doubles, when any other type would work just as well as a label?

I guess it is mostly a matter of keeping the API simple and reusable. This way LabeledPoint can be used for both classification and regression problems. Moreover, it is an efficient representation in terms of memory usage and computational cost.
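
To illustrate the reuse point, here is a minimal sketch with made-up feature values: the same `LabeledPoint` type carries a class index in one case and a continuous target in the other.

```scala
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// Feature values below are arbitrary, chosen only for illustration.
val features = Vectors.dense(0.5, 1.2)

// Classification: the label is a 0-based class index stored as a Double.
val classificationPoint = LabeledPoint(1.0, features)

// Regression: the label is a continuous target, also a Double.
val regressionPoint = LabeledPoint(37.5, features)
```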

zero323 answered Sep 29 '22 07:09