
How to draw a precision-recall curve with interpolation in python?

I have drawn a precision-recall curve using sklearn's precision_recall_curve function and the matplotlib package. Those of you who are familiar with precision-recall curves will know that some scientific communities only accept them when they are interpolated, similar to this example here. Does anyone know how to do the interpolation in Python? I have been searching for a solution for a while now, without success. Any help would be greatly appreciated.

Solution: Both solutions by @francis and @ali_m are correct and together solved my problem. So, assuming that you get an output from the precision_recall_curve function in sklearn, here is what I did to plot the graph:

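import copy
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve

# y_test, scores, the precision/recall dicts, markers, methodNames, average_precision and the counters come from the rest of the script.
precision["micro"], recall["micro"], _ = precision_recall_curve(y_test.ravel(), scores.ravel())
# Key 0 (not "micro") is read here, as in the original post; it is assumed to be filled elsewhere.
pr = copy.deepcopy(precision[0])
rec = copy.deepcopy(recall[0])
# precision_recall_curve returns recall in decreasing order; flip both arrays so recall increases.
prInv = np.fliplr([pr])[0]
recInv = np.fliplr([rec])[0]
# Walk from high recall to low recall, raising each precision to the best value found at any higher recall.
j = rec.shape[0]-2
while j>=0:
    if prInv[j+1]>prInv[j]:
        prInv[j]=prInv[j+1]
    j=j-1
# Vectorized form of the loop above; after the loop it leaves prInv unchanged, so either step alone is enough.
decreasing_max_precision = np.maximum.accumulate(prInv[::-1])[::-1]
plt.plot(recInv, decreasing_max_precision, marker=markers[mcounter], label=methodNames[countOfMethods]+': AUC={0:0.2f}'.format(average_precision[0]))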

Placed inside a for loop that is passed the data of one method per iteration, these lines plot the interpolated curve for each method. Note that they do not plot the non-interpolated precision-recall curves.
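For readers who want the interpolation step on its own, here is a minimal sketch rather than the asker's exact script. It assumes the inputs follow the ordering that sklearn's precision_recall_curve returns (recall decreasing and ending at 0). The helper name interpolate_pr and the sample values are made up for illustration.

```python
import numpy as np

def interpolate_pr(precision, recall):
    """Return (recall, interpolated precision), sorted by increasing recall.

    Assumes the inputs are ordered like the output of sklearn's
    precision_recall_curve: recall decreasing, last element at recall 0.
    """
    precision = np.asarray(precision, dtype=float)
    recall = np.asarray(recall, dtype=float)
    # With recall decreasing, a forward running maximum gives, for each point,
    # the best precision observed at any recall >= the current recall.
    interpolated = np.maximum.accumulate(precision)
    # Reverse both arrays so recall increases from left to right for plotting.
    return recall[::-1], interpolated[::-1]

# Made-up values in sklearn order (not real model output).
precision = np.array([0.5, 0.6, 0.55, 0.8, 0.7, 1.0])
recall = np.array([1.0, 0.8, 0.6, 0.4, 0.2, 0.0])

rec_inc, prec_interp = interpolate_pr(precision, recall)
print(rec_inc)
print(prec_interp)
```

Under that ordering assumption no explicit flip is needed before taking the running maximum. The flip only matters for drawing the curve with recall increasing along the x axis.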

asked Oct 03 '16 17:10 by user823743


1 Answer

@francis's solution can be vectorized using np.maximum.accumulate.

import numpy as np
import matplotlib.pyplot as plt

recall = np.linspace(0.0, 1.0, num=42)
precision = np.random.rand(42)*(1.-recall)

# take a running maximum over the reversed vector of precision values, reverse the
# result to match the order of the recall vector
decreasing_max_precision = np.maximum.accumulate(precision[::-1])[::-1]
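As a quick sanity check, the sketch below compares an explicit backward loop with the np.maximum.accumulate form on the same kind of random data. The function names interpolate_loop and interpolate_vectorized are hypothetical, and the seeded generator is only there to make the run repeatable.

```python
import numpy as np

def interpolate_loop(precision):
    """Loop version; precision is ordered by increasing recall."""
    out = np.array(precision, dtype=float)
    # Walk from the highest recall down, carrying the running maximum.
    for j in range(len(out) - 2, -1, -1):
        out[j] = max(out[j], out[j + 1])
    return out

def interpolate_vectorized(precision):
    """Vectorized version: running maximum over the reversed array."""
    precision = np.asarray(precision, dtype=float)
    return np.maximum.accumulate(precision[::-1])[::-1]

rng = np.random.default_rng(0)  # seeded for repeatability
recall = np.linspace(0.0, 1.0, num=42)
precision = rng.random(42) * (1.0 - recall)

loop_result = interpolate_loop(precision)
vec_result = interpolate_vectorized(precision)
same = np.array_equal(loop_result, vec_result)
print(same)
```

Both versions only take maxima of existing values, so the results should match exactly and not merely up to rounding.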

You can also use plt.step to get rid of the for loop used for plotting:

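fig, ax = plt.subplots(1, 1)
# ax.hold(True) is no longer needed: hold was removed in matplotlib 3.0 and overplotting is the default
ax.plot(recall, precision, '--b')                  # raw curve, dashed blue
ax.step(recall, decreasing_max_precision, '-r')    # interpolated curve, solid red
plt.show()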
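The step plot treats the interpolated curve as a piecewise-constant function. With the default where='pre', the interval (recall[i-1], recall[i]] takes the value at index i. If you need that function's value at arbitrary recall levels, one possible sketch uses np.searchsorted. The helper name interpolated_precision_at, the sample arrays, and the choice to return 0 beyond the last measured recall are all assumptions made for illustration.

```python
import numpy as np

def interpolated_precision_at(recall_levels, recall, interp_precision):
    """Evaluate the step-shaped interpolated curve at the given recall levels.

    recall must be sorted in increasing order, and interp_precision is the
    running-maximum precision aligned with it.
    """
    recall_levels = np.asarray(recall_levels, dtype=float)
    # Index of the first measured recall that is >= each requested level.
    idx = np.searchsorted(recall, recall_levels, side="left")
    # Levels above the last measured recall have no point to their right;
    # this sketch reports a precision of 0 there (an assumption).
    padded = np.append(interp_precision, 0.0)
    return padded[idx]

# Made-up values, recall increasing and precision already interpolated.
recall = np.array([0.0, 0.2, 0.4, 0.6, 0.8, 1.0])
interp = np.array([1.0, 0.8, 0.8, 0.6, 0.6, 0.5])

levels = np.array([0.0, 0.1, 0.4, 0.5, 1.0])
values = interpolated_precision_at(levels, recall, interp)
print(values)
```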


answered Oct 14 '22 14:10 by ali_m