Print Estimator Name in SkLearn

In Sklearn, is there a way to print out an estimator's class name?

I have tried to use the __name__ attribute, but that is not working.

from sklearn.linear_model import LogisticRegression  

def print_estimator_name(estimator):
    print(estimator.__name__)

#Expected Outcome:
print_estimator_name(LogisticRegression())

I would expect this to print out the classifier's class name, i.e. LogisticRegression.

Odisseo asked Jan 05 '19 06:01




1 Answer

I have an alternative method: get the object's type, convert it to a str, take the last dot-separated component (the class itself) with split("."), and finally strip off the trailing '> characters.

str(type(clf)).split(".")[-1][:-2]
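To make the individual steps visible, here is a minimal sketch that splits the one-liner into intermediate values. It uses collections.OrderedDict as a stand-in object so it runs without scikit-learn installed; the variable names are only illustrative, and an estimator instance should behave the same way.

```python
from collections import OrderedDict  # stand-in for an estimator instance

obj = OrderedDict()

as_text = str(type(obj))            # "<class 'collections.OrderedDict'>"
last_part = as_text.split(".")[-1]  # "OrderedDict'>"
class_name = last_part[:-2]         # drop the trailing "'>"

print(class_name)  # OrderedDict
```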

This works for me with scikit-learn, XGBoost, and LightGBM:

print(f'Acc: {pred:0.7f} : {str(type(clf)).split(".")[-1][:-2]}')
Acc: 0.7159443 : DecisionTreeClassifier
Acc: 0.7572368 : RandomForestClassifier
Acc: 0.7548593 : ExtraTreesClassifier
Acc: 0.7416970 : XGBClassifier
Acc: 0.7582540 : LGBMClassifier
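
For what it's worth, the class name can also be read directly from the type without any string parsing, since every Python class carries a __name__ attribute (instances do not, which is why estimator.__name__ fails). A minimal sketch, where estimator_name and FakeEstimator are hypothetical names used only so the example runs without scikit-learn:

```python
def estimator_name(estimator):
    """Return the class name of the given object, e.g. an estimator instance."""
    # __name__ lives on the class, not on the instance, so go through type().
    return type(estimator).__name__


# Hypothetical stand-in class; a real estimator instance could be passed instead.
class FakeEstimator:
    pass


print(estimator_name(FakeEstimator()))  # FakeEstimator
print(estimator_name(3))                # int
```

Unlike the string-splitting version, this also handles built-in types such as int, whose str(type(...)) contains no "." to split on.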
Ali Pardhan answered Sep 22 '22 05:09
