
How to set a threshold for a scikit-learn random forest model

After inspecting the precision_recall_curve, I want to use a threshold of 0.4 with my random forest model (binary classification): any sample with a predicted probability < 0.4 should be labeled 0, and any sample with a probability >= 0.4 should be labeled 1. How do I implement this?

from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

random_forest = RandomForestClassifier(n_estimators=100, oob_score=True, random_state=12)
random_forest.fit(X_train, y_train)

predicted = random_forest.predict(X_test)
accuracy = accuracy_score(y_test, predicted)

Documentation: Precision-Recall

Asked by BigData, Apr 11 '18 23:04


2 Answers

Assuming you are doing binary classification, it's quite easy:

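threshold = 0.4

# Column 1 of predict_proba holds the probability of the second class in random_forest.classes_ (class 1 here)
predicted_proba = random_forest.predict_proba(X_test)
predicted = (predicted_proba[:, 1] >= threshold).astype('int')

accuracy = accuracy_score(y_test, predicted)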
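
For anyone who wants to try this end to end, here is a minimal self-contained sketch of the same idea. The synthetic dataset from make_classification, the split sizes, and the names positive_col, custom_pred, and default_pred are assumptions made for illustration; they are not from the question. It looks up the class-1 column through classes_ rather than hard-coding index 1, and compares the 0.4 cut-off against the model's default predict.

```python
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for the asker's data (an assumption, for illustration only).
X, y = make_classification(n_samples=500, n_features=10, random_state=12)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.3, random_state=12
)

random_forest = RandomForestClassifier(n_estimators=100, random_state=12)
random_forest.fit(X_train, y_train)

threshold = 0.4

# predict_proba columns follow the order of classes_, so find the column for label 1.
positive_col = list(random_forest.classes_).index(1)
proba_positive = random_forest.predict_proba(X_test)[:, positive_col]

# Label 1 when P(class 1) >= threshold, otherwise 0.
custom_pred = (proba_positive >= threshold).astype(int)

# The default predict picks the class with the highest averaged probability.
default_pred = random_forest.predict(X_test)

print("accuracy at threshold 0.4:", accuracy_score(y_test, custom_pred))
print("accuracy with default predict:", accuracy_score(y_test, default_pred))
```

Lowering the threshold below 0.5 can only turn predicted 0s into 1s, so every sample the default predict labels 1 is also labeled 1 here. Whether accuracy goes up or down depends on the data.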
Answered by Stev, Oct 09 '22 15:10


random_forest = RandomForestClassifier(n_estimators=100)
random_forest.fit(X_train, y_train)

threshold = 0.4

predicted = random_forest.predict_proba(X_test)
predicted[:,0] = (predicted[:,0] < threshold).astype('int')
predicted[:,1] = (predicted[:,1] >= threshold).astype('int')


accuracy = accuracy_score(y_test, predicted)
print(round(accuracy,4,)*100, "%")

This raises an error at the final accuracy_score call: "ValueError: Can't handle mix of binary and multilabel-indicator".
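
A likely cause, sketched below with small hand-made arrays (the values in y_test and proba are invented for illustration, not taken from the post): the snippet above keeps both columns of predict_proba, so predicted is a 2-D array of 0s and 1s, which scikit-learn treats as a multilabel-indicator target, while y_test is a 1-D binary vector. The two columns are also thresholded by different rules, so a row can end up as [0, 0] or [1, 1]. Keeping only the class-1 column gives a 1-D label vector that accuracy_score accepts.

```python
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.utils.multiclass import type_of_target

# Hand-made stand-ins for y_test and predict_proba output (assumptions for illustration).
y_test = np.array([0, 1, 1, 0])
proba = np.array([
    [0.70, 0.30],
    [0.55, 0.45],
    [0.20, 0.80],
    [0.90, 0.10],
])
threshold = 0.4

# Reproduce the failing approach: overwrite both columns in place.
two_col = proba.copy()
two_col[:, 0] = (two_col[:, 0] < threshold).astype(int)
two_col[:, 1] = (two_col[:, 1] >= threshold).astype(int)

# The target types differ, which is what accuracy_score complains about.
true_type = type_of_target(y_test)
pred_type = type_of_target(two_col)
print(true_type, "vs", pred_type)

try:
    accuracy_score(y_test, two_col)
    error_message = None
except ValueError as exc:
    error_message = str(exc)
    print("ValueError:", error_message)

# Fix: threshold only the class-1 column, giving a 1-D array of labels.
predicted = (proba[:, 1] >= threshold).astype(int)
accuracy = accuracy_score(y_test, predicted)
print(round(accuracy * 100, 2), "%")
```

With real data, replace proba with random_forest.predict_proba(X_test); the rest stays the same.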

Answered by BigData, Oct 09 '22 15:10