I am trying to plot the ROC curve and the Precision-Recall curve from points generated by Spark MLlib's BinaryClassificationMetrics, following the Spark evaluation-metrics guide: https://spark.apache.org/docs/latest/mllib-evaluation-metrics.html. These are the outputs I get:
Precision: [(1.0,1.0), (0.0,0.4444444444444444)]
Recall: [(1.0,1.0), (0.0,1.0)]
F1-measure: [(1.0,1.0), (0.0,0.6153846153846153)]
Precision-Recall curve: [(0.0,1.0), (1.0,1.0), (1.0,0.4444444444444444)]
ROC curve: [(0.0,0.0), (0.0,1.0), (1.0,1.0), (1.0,1.0)]
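Roughly, the points above were produced like this, following the guide (a sketch; the name predictionAndLabels, holding (prediction, label) pairs, is my assumption):

import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics

// predictionAndLabels: RDD[(Double, Double)] of (prediction, label) pairs
val metrics = new BinaryClassificationMetrics(predictionAndLabels)
metrics.precisionByThreshold().collect()  // the Precision points above
metrics.roc().collect()                   // the ROC curve points above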
To plot the ROC curve, we need to calculate the TPR and FPR for many different thresholds (this step is built into all relevant libraries, such as scikit-learn). For each threshold, we plot the FPR value on the x-axis and the TPR value on the y-axis. We then join the dots with a line. That's it!
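As a rough, self-contained illustration of that threshold sweep (plain Scala with made-up scores, not Spark API):

// Each element is (score, label); labels are 0.0 or 1.0.
val scored = Seq((0.9, 1.0), (0.7, 1.0), (0.4, 0.0), (0.2, 1.0), (0.1, 0.0))
val positives = scored.count(_._2 == 1.0).toDouble
val negatives = scored.size - positives
// One (FPR, TPR) point per distinct threshold, highest threshold first.
val rocPoints = scored.map(_._1).distinct.sorted.reverse.map { t =>
  val tp = scored.count { case (s, y) => s >= t && y == 1.0 }
  val fp = scored.count { case (s, y) => s >= t && y == 0.0 }
  (fp / negatives, tp / positives)
}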
ROC curves are appropriate when the observations are roughly balanced between the classes, whereas precision-recall curves are more informative for imbalanced datasets.
It looks like you have a problem similar to one I ran into. You need to either flip the parameters to the Metrics constructor, or, more likely, pass in the probability instead of the prediction. For example, if you are using BinaryClassificationMetrics with a RandomForestClassifier, then according to this page (under outputs) the model produces both a "prediction" column and a "probability" column.
Then initialize your Metrics thus:
import org.apache.spark.ml.linalg.DenseVector
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
import org.apache.spark.sql.functions.col

// Feed (probability of class 1, label) pairs instead of hard predictions.
val metrics = new BinaryClassificationMetrics(predictionsWithResponse
  .select(col("probability"), col("myLabel"))
  .rdd.map(r => (r.getAs[DenseVector](0)(1), r.getDouble(1))))
The DenseVector indexing is what extracts the probability of class 1 (the positive class) from the probability column.
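For instance, for a row whose probability vector were [0.35, 0.65] (made-up values):

r.getAs[DenseVector](0)     // DenseVector(0.35, 0.65) = [P(class 0), P(class 1)]
r.getAs[DenseVector](0)(1)  // 0.65, the probability of class 1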
As for the actual plotting, that's up to you (there are many fine tools for it), but at least you will now get more than one point on your curve (besides the endpoints).
And in case it's not clear:
metrics.roc().collect()
will give you the data for the ROC curve: tuples of (false positive rate, true positive rate).
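If it helps, one minimal way to get those points out for plotting is to write them to CSV (the file name here is an arbitrary choice):

import java.io.PrintWriter

// Collect the (FPR, TPR) points to the driver and write them as CSV.
val rocData = metrics.roc().collect()
val out = new PrintWriter("roc.csv")
out.println("fpr,tpr")
rocData.foreach { case (fpr, tpr) => out.println(s"$fpr,$tpr") }
out.close()

The precision-recall points from metrics.pr().collect(), which are tuples of (recall, precision), can be written out the same way.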