 

How to set a threshold for a sklearn classifier based on ROC results?

I trained an ExtraTreesClassifier (Gini index) using scikit-learn and it suits my needs fairly well. The accuracy is not great, but with 10-fold cross-validation the AUC is 0.95. I would like to use this classifier in my work. I am quite new to ML, so please forgive me if I'm asking something conceptually wrong.

I plotted some ROC curves, and from them it seems there is a specific threshold where my classifier starts performing well. I'd like to set this value on the fitted classifier, so that every time I call predict, the classifier uses that threshold and I can trust the FP and TP rates.

I also came across this post (scikit .predict() default threshold), where it's stated that a threshold is not a generic concept for classifiers. But since ExtraTreesClassifier has the method predict_proba, and the ROC curve is also defined in terms of thresholds, it seems to me I should be able to specify one.

I did not find any parameter, nor any class/interface, to do this. How can I set a threshold on a trained ExtraTreesClassifier (or any other classifier) using scikit-learn?

Many Thanks, Colis



People also ask

How do you select a threshold with a ROC curve?

A really easy way to pick a threshold is to take the median predicted probability of the positive cases in a test set. This becomes your threshold. It comes reasonably close to the threshold you would get from the ROC curve at the point where the true positive rate (TPR) and 1 - false positive rate (FPR) overlap.
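As a rough sketch of both ideas (the names y_test and probas are placeholders for the true labels and the positive-class probabilities, not from the question):

import numpy as np
from sklearn.metrics import roc_curve

# Option 1: median predicted probability of the actual positive cases
median_threshold = np.median(probas[y_test == 1])

# Option 2: threshold where TPR and 1 - FPR cross (sensitivity == specificity)
fpr, tpr, thresholds = roc_curve(y_test, probas)
crossing_threshold = thresholds[np.argmin(np.abs(tpr - (1 - fpr)))]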

How do you set a threshold for a confusion matrix?

We can set a threshold value and classify all predictions greater than the threshold as 1 and all others as 0. That is how Y is predicted and we get 'Y-predicted'. The default threshold at which a confusion matrix is usually generated is 0.50. This is where things start to get interesting.

What is the threshold in a ROC curve?

The ROC curve is produced by calculating and plotting the true positive rate against the false positive rate for a single classifier at a variety of thresholds. For example, in logistic regression, the threshold would be the predicted probability of an observation belonging to the positive class.
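A minimal sketch of that idea, using a logistic regression on synthetic data (the dataset and model here are illustrative assumptions):

from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import roc_curve, roc_auc_score
from sklearn.model_selection import train_test_split

X, y = make_classification(n_samples=1000, random_state=0)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)
probas = clf.predict_proba(X_test)[:, 1]   # predicted probability of the positive class

# roc_curve sweeps the candidate thresholds and returns one per (FPR, TPR) point
fpr, tpr, thresholds = roc_curve(y_test, probas)
print("AUC:", roc_auc_score(y_test, probas))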

How is a threshold determined for classification algorithms?

A classification threshold must be defined if you want to convert a logistic regression output into a binary category. A value greater than the threshold denotes "spam," whereas a value below it suggests "not spam." It's easy to assume that the classification threshold is always going to be 0.5…


1 Answer

This is what I have done:

import numpy as np
from sklearn.metrics import roc_curve, confusion_matrix

model = SomeSklearnModel()
model.fit(X_train, y_train)
predict = model.predict(X_test)  # uses the model's built-in decision rule (0.5 cut-off for binary probabilities)
predict_probabilities = model.predict_proba(X_test)[:, 1]  # keep only the positive-class column
fpr, tpr, thresholds = roc_curve(y_test, predict_probabilities)

However, I am annoyed that predict chooses a threshold corresponding to only 0.4% true positives (with zero false positives). The ROC curve shows a threshold I like better for my problem, where the true positive rate is approximately 20% (false positive rate around 4%). I then scan predict_probabilities to find which probability value corresponds to my favourite ROC point. In my case this probability is 0.21. Then I create my own predict array:

predict_mine = np.where(predict_probabilities > 0.21, 1, 0)

and there you go:

confusion_matrix(y_test, predict_mine)

returns what I wanted:

array([[6927,  309],
       [ 621,  121]])
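
If it helps, roc_curve also returns the thresholds it evaluated, so one could look up the favourite operating point directly instead of scanning the probabilities by hand (a sketch; the 20% TPR target just mirrors the numbers above):

# thresholds[i] is the cut-off that produces the point (fpr[i], tpr[i])
idx = np.argmax(tpr >= 0.20)           # first point where the TPR reaches ~20%
chosen_threshold = thresholds[idx]     # roughly 0.21 in the case described above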
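To make the chosen threshold apply every time predict is called, as the question asks, one option is a thin wrapper around the fitted model (a sketch, not a built-in scikit-learn interface; the class name is made up):

class ThresholdedClassifier:
    """Hypothetical wrapper that applies a fixed probability threshold in predict()."""

    def __init__(self, model, threshold):
        self.model = model            # an already-fitted classifier exposing predict_proba
        self.threshold = threshold    # probability cut-off for the positive class

    def predict(self, X):
        return (self.model.predict_proba(X)[:, 1] >= self.threshold).astype(int)

clf = ThresholdedClassifier(model, 0.21)
predict_mine = clf.predict(X_test)

If you are on a recent scikit-learn release, it is also worth checking whether sklearn.model_selection provides FixedThresholdClassifier, which wraps an estimator with a fixed decision threshold in much the same way.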