
How to plot ROC curve and precision-recall curve from BinaryClassificationMetrics

I am trying to plot the ROC curve and the precision-recall curve on a graph. The points are generated by Spark MLlib's BinaryClassificationMetrics. I followed the Spark guide at https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html and got these points:

[(1.0,1.0), (0.0,0.4444444444444444)] - Precision
[(1.0,1.0), (0.0,1.0)] - Recall
[(1.0,1.0), (0.0,0.6153846153846153)] - F1Measure
[(0.0,1.0), (1.0,1.0), (1.0,0.4444444444444444)] - Precision-Recall curve
[(0.0,0.0), (0.0,1.0), (1.0,1.0), (1.0,1.0)] - ROC curve
Desanth pv asked Jul 05 '16 16:07


People also ask

How do you plot a ROC curve?

To plot the ROC curve, we need to calculate the TPR and FPR for many different thresholds (this step is built into the relevant libraries, such as scikit-learn). For each threshold, we plot the FPR value on the x-axis and the TPR value on the y-axis. We then join the dots with a line.

Is ROC curve the same as precision-recall curve?

No. ROC curves are generally appropriate when the observations are balanced between the classes, whereas precision-recall curves are usually more informative for imbalanced datasets.


1 Answer

It looks like you have a problem similar to one I ran into. You need to either swap the order of the parameters you pass to the metrics constructor or pass in the probability instead of the prediction. For example, if you are using BinaryClassificationMetrics with a RandomForestClassifier, then according to this page (under outputs) the output columns include "prediction" and "probability".

Then initialize your Metrics thus:

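    import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
    import org.apache.spark.sql.functions.col
    // Build (score, label) pairs; import DenseVector from the linalg package your Spark version uses.
    val metrics = new BinaryClassificationMetrics(predictionsWithResponse
      .select(col("probability"), col("myLabel"))
      .rdd.map(r => (r.getAs[DenseVector](0)(1), r.getDouble(1))))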

Here, indexing the DenseVector at position 1 extracts the probability of class 1.
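
To illustrate why hard 0/1 predictions yield so few points, here is a minimal plain-Scala sketch (no Spark involved) that emits one ROC point per distinct score value. The names RocSketch and rocPoints and both datasets are hypothetical, made up for this illustration; the logic only approximates a threshold sweep and is not Spark's implementation.

```scala
// Plain-Scala sketch (no Spark): one ROC point per distinct score value.
object RocSketch {
  // Input pairs are (score, label), with label 1.0 for the positive class.
  def rocPoints(scoreAndLabels: Seq[(Double, Double)]): Seq[(Double, Double)] = {
    val positives = scoreAndLabels.count(_._2 == 1.0).toDouble
    val negatives = scoreAndLabels.size - positives
    // Each distinct score acts as a threshold, from highest to lowest.
    val thresholds = scoreAndLabels.map(_._1).distinct.sorted.reverse
    val inner = thresholds.map { t =>
      val predictedPositive = scoreAndLabels.filter(_._1 >= t)
      val tp = predictedPositive.count(_._2 == 1.0)
      val fp = predictedPositive.size - tp
      // (false positive rate, true positive rate)
      (fp / negatives, tp / positives)
    }
    // Add endpoints so the curve spans (0,0) to (1,1).
    Seq((0.0, 0.0)) ++ inner ++ Seq((1.0, 1.0))
  }

  // Hypothetical data: 4 positives and 5 negatives, passed as hard predictions.
  val hardPredictions: Seq[(Double, Double)] =
    Seq.fill(4)((1.0, 1.0)) ++ Seq.fill(5)((0.0, 0.0))

  // The same labels scored with made-up class-1 probabilities.
  val probabilities: Seq[(Double, Double)] = Seq(
    (0.95, 1.0), (0.80, 1.0), (0.70, 0.0), (0.65, 1.0), (0.40, 1.0),
    (0.35, 0.0), (0.20, 0.0), (0.10, 0.0), (0.05, 0.0))
}
```

With hardPredictions the function returns only (0.0,0.0), (0.0,1.0), (1.0,1.0), (1.0,1.0), the same shape as the ROC output in the question, because there are just two distinct scores. With probabilities there are nine distinct scores, so it returns eleven points.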

As for the actual plotting, that's up to you (there are many fine tools for that), but at least you will get more than one point on your curve (besides the endpoints).

And in case it's not clear:

metrics.roc().collect() will give you the data for the ROC curve as tuples of (false positive rate, true positive rate).
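
Once the points are collected, one simple way to hand them to a plotting tool is a CSV file. The following is a rough sketch in plain Scala; the object CurveExport, its two helpers, and the column names are hypothetical and not part of Spark.

```scala
import java.io.PrintWriter

object CurveExport {
  // Renders (x, y) curve points as CSV text with a header row.
  def toCsv(points: Seq[(Double, Double)], xName: String, yName: String): String =
    (s"$xName,$yName" +: points.map { case (x, y) => s"$x,$y" }).mkString("\n")

  // Writes the CSV text to a file that a plotting tool can read.
  def writeCsv(path: String, csv: String): Unit = {
    val out = new PrintWriter(path)
    try out.write(csv) finally out.close()
  }
}
```

Assuming a metrics value like the one above, usage could look like CurveExport.writeCsv("roc.csv", CurveExport.toCsv(metrics.roc().collect().toSeq, "fpr", "tpr")). The precision-recall curve is available the same way through metrics.pr(), which returns (recall, precision) pairs, so recall belongs on the x-axis.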

Jeremy answered Oct 11 '22 19:10