How to get accuracy, precision, recall and ROC from cross validation in Spark ml lib?

I am using Spark 2.0.2. I am also using the "ml" library for Machine Learning with Datasets. What I want to do is run algorithms with cross validation and extract the mentioned metrics (accuracy, precision, recall, ROC, confusion matrix). My data labels are binary.

By using the MulticlassClassificationEvaluator I can only get the accuracy of the algorithm, by accessing "avgMetrics". By using the BinaryClassificationEvaluator I can get the area under ROC. But a CrossValidator accepts only one evaluator, so I cannot use both at once. Is there a way to extract all of the wanted metrics?
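For reference, here is a minimal sketch of the setup described above. The logistic regression estimator, the parameter grid, and the training DataFrame (with "features" and "label" columns) are placeholder assumptions, not part of the original question:

import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, MulticlassClassificationEvaluator}
import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}
import org.apache.spark.sql.DataFrame

// "training" must contain "features" and "label" columns (assumption).
def crossValidate(training: DataFrame): Unit = {
  val lr = new LogisticRegression()
  val grid = new ParamGridBuilder()
    .addGrid(lr.regParam, Array(0.01, 0.1))
    .build()

  // A CrossValidator takes exactly one Evaluator, so you must pick one metric:
  val accuracyEval = new MulticlassClassificationEvaluator().setMetricName("accuracy")
  val rocEval = new BinaryClassificationEvaluator().setMetricName("areaUnderROC")

  val cv = new CrossValidator()
    .setEstimator(lr)
    .setEvaluator(accuracyEval) // or rocEval, but not both at once
    .setEstimatorParamMaps(grid)
    .setNumFolds(5)

  val cvModel = cv.fit(training)
  // avgMetrics holds one cross-validated value of the chosen metric per param map.
  println(cvModel.avgMetrics.mkString(", "))
}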

asked Jan 18 '17 by user3309479

People also ask

What is the difference between Spark ML and Spark MLlib?

At first glance, the most obvious difference between MLlib and ML is the data types they work on: MLlib supports RDDs, while ML supports DataFrames and Datasets.
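As a rough illustration of that type difference, here is a sketch computing the same metric through both APIs (both functions assume the conventional "prediction" and "label" column names):

// MLlib (RDD-based): metrics classes consume RDDs of primitive tuples.
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.rdd.RDD

// ml (DataFrame-based): evaluators consume DataFrames.
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.sql.DataFrame

def mllibAccuracy(predictionAndLabel: RDD[(Double, Double)]): Double =
  new MulticlassMetrics(predictionAndLabel).accuracy

def mlAccuracy(predictions: DataFrame): Double =
  new MulticlassClassificationEvaluator()
    .setMetricName("accuracy")
    .evaluate(predictions) // expects "prediction" and "label" columns by default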

What are precision and recall metrics?

Precision and recall are two extremely important model evaluation metrics. While precision refers to the percentage of your results which are relevant, recall refers to the percentage of total relevant results correctly classified by your algorithm.

What is precision recall accuracy in ML?

Accuracy tells you how many times the ML model was correct overall. Precision is how good the model is at predicting a specific category. Recall tells you how many times the model was able to detect a specific category.
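To make those three definitions concrete, here is a tiny self-contained sketch with made-up confusion-matrix counts for a binary classifier (the numbers are illustrative only):

// Illustrative counts: 100 predictions in total (made up).
val tp = 40.0 // positives correctly predicted positive
val fp = 10.0 // negatives wrongly predicted positive
val fn = 5.0  // positives wrongly predicted negative
val tn = 45.0 // negatives correctly predicted negative

val accuracy  = (tp + tn) / (tp + fp + fn + tn) // 0.85:  correct overall
val precision = tp / (tp + fp)                  // 0.80:  how trustworthy a "positive" call is
val recall    = tp / (tp + fn)                  // ~0.89: how many real positives were found

println(s"accuracy=$accuracy precision=$precision recall=$recall")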


2 Answers

I tried using MLlib to evaluate the result.

I transformed the Dataset to an RDD, then used MulticlassMetrics from MLlib.

You can see a demo here: Spark DecisionTreeExample.scala

import org.apache.spark.ml.Transformer
// MetadataUtils is Spark-internal (private[spark]); this snippet lives in the Spark source tree.
import org.apache.spark.ml.util.MetadataUtils
import org.apache.spark.mllib.evaluation.MulticlassMetrics
import org.apache.spark.sql.DataFrame

private[ml] def evaluateClassificationModel(
      model: Transformer,
      data: DataFrame,
      labelColName: String): Unit = {
    val fullPredictions = model.transform(data).cache()
    // Extract the prediction and label columns as RDD[Double] for MLlib.
    val predictions = fullPredictions.select("prediction").rdd.map(_.getDouble(0))
    val labels = fullPredictions.select(labelColName).rdd.map(_.getDouble(0))
    // Print number of classes for reference.
    val numClasses = MetadataUtils.getNumClasses(fullPredictions.schema(labelColName)) match {
      case Some(n) => n
      case None => throw new RuntimeException(
        "Unknown failure when indexing labels for classification.")
    }
    // MulticlassMetrics consumes an RDD of (prediction, label) pairs.
    val accuracy = new MulticlassMetrics(predictions.zip(labels)).accuracy
    println(s"  Accuracy ($numClasses classes): $accuracy")
  }
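Building on the same (prediction, label) RDD, MLlib's metrics classes can produce the rest of the metrics the question asks for. A sketch under the same assumptions (note that for a meaningful ROC curve, BinaryClassificationMetrics should ideally be fed raw scores rather than hard 0/1 predictions):

import org.apache.spark.mllib.evaluation.{BinaryClassificationMetrics, MulticlassMetrics}
import org.apache.spark.rdd.RDD

def printAllMetrics(predictions: RDD[Double], labels: RDD[Double]): Unit = {
  val predictionAndLabels = predictions.zip(labels)

  val multiclass = new MulticlassMetrics(predictionAndLabels)
  println(s"Accuracy:           ${multiclass.accuracy}")
  println(s"Weighted precision: ${multiclass.weightedPrecision}")
  println(s"Weighted recall:    ${multiclass.weightedRecall}")
  println(s"Confusion matrix:\n${multiclass.confusionMatrix}")

  // Per-class precision/recall for the binary labels 0.0 and 1.0.
  multiclass.labels.foreach { l =>
    println(s"Precision($l) = ${multiclass.precision(l)}, Recall($l) = ${multiclass.recall(l)}")
  }

  // Area under ROC. BinaryClassificationMetrics expects (score, label) pairs;
  // passing hard predictions works but yields a coarse one-point ROC curve.
  val binary = new BinaryClassificationMetrics(predictionAndLabels)
  println(s"Area under ROC: ${binary.areaUnderROC()}")
}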
answered Oct 03 '22 by ShuoshuoFan


You can follow the official Evaluation Metrics guide provided by Apache Spark. It covers all of the evaluation metrics you are after, including:

  • Precision (Positive Predictive Value)
  • Recall (True Positive Rate)
  • F-measure
  • Receiver Operating Characteristic (ROC)
  • Area Under ROC Curve
  • Area Under Precision-Recall Curve

Here is the link : https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html
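For example, the binary classification section of that guide works from an RDD of (score, label) pairs. A minimal sketch along those lines (the scoreAndLabels input is assumed to come from your trained model; it is a placeholder here):

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.rdd.RDD

// scoreAndLabels: RDD[(Double, Double)] of (predicted score, true label) pairs.
def printBinaryMetrics(scoreAndLabels: RDD[(Double, Double)]): Unit = {
  val metrics = new BinaryClassificationMetrics(scoreAndLabels)

  // Precision and recall at each score threshold.
  metrics.precisionByThreshold.foreach { case (t, p) => println(s"Threshold $t: precision $p") }
  metrics.recallByThreshold.foreach { case (t, r) => println(s"Threshold $t: recall $r") }

  // Aggregate curve areas.
  println(s"Area under precision-recall curve: ${metrics.areaUnderPR()}")
  println(s"Area under ROC: ${metrics.areaUnderROC()}")
}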

answered Oct 03 '22 by Darshan