
How to duplicate an estimator in order to use it on multiple data sets?

Here is an example that creates two data sets:

from sklearn.linear_model import LogisticRegression
from sklearn.datasets import make_classification

# data set 1
X1, y1 = make_classification(n_classes=2, n_features=5, random_state=1)
# data set 2
X2, y2 = make_classification(n_classes=2, n_features=5, random_state=2)

I want to use the LogisticRegression estimator with the same parameter values to fit a classifier on each data set:

lr = LogisticRegression()

clf1 = lr.fit(X1, y1)
clf2 = lr.fit(X2, y2)

print("Classifier for data set 1: ")
print("  - intercept: ", clf1.intercept_)
print("  - coef_: ", clf1.coef_)

print("Classifier for data set 2: ")
print("  - intercept: ", clf2.intercept_)
print("  - coef_: ", clf2.coef_)

The problem is that both classifiers are the same:

Classifier for data set 1:
  - intercept:  [ 0.05191729]
  - coef_:  [[ 0.06704494  0.00137751 -0.12453698 -0.05999127  0.05798146]]
Classifier for data set 2:
  - intercept:  [ 0.05191729]
  - coef_:  [[ 0.06704494  0.00137751 -0.12453698 -0.05999127  0.05798146]]
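The output is identical because scikit-learn's fit() fits the estimator in place and returns the estimator itself (self). So lr, clf1 and clf2 are three names for one object, and the second fit call overwrites what the first one learned. A minimal check, assuming Python 3 and a reasonably recent scikit-learn (the two data sets are recreated so the snippet runs on its own):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X1, y1 = make_classification(n_classes=2, n_features=5, random_state=1)
X2, y2 = make_classification(n_classes=2, n_features=5, random_state=2)

lr = LogisticRegression()
clf1 = lr.fit(X1, y1)  # fit() returns lr itself
clf2 = lr.fit(X2, y2)  # refits the same object, replacing the data set 1 model

# All three names refer to a single estimator object.
print(clf1 is lr, clf2 is lr, clf1 is clf2)  # True True True
```

Because only one object exists, printing clf1 and clf2 necessarily shows the model learned from the last data set it was fitted on.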

For this simple example, I could use something like:

lr1 = LogisticRegression()
lr2 = LogisticRegression()

clf1 = lr1.fit(X1, y1)
clf2 = lr2.fit(X2, y2)

to avoid the problem. However, the general question remains: how do I duplicate (copy) an estimator together with its particular parameter values?

asked Dec 04 '12 11:12 by tjanez


People also ask

What does the fit() method do?

The fit() method takes the training data as arguments: one array in the case of unsupervised learning, or two arrays in the case of supervised learning. Note that the model is fitted using X and y, but the object holds no reference to X and y.
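To illustrate the one-array versus two-array cases, here is a small sketch; KMeans and the single synthetic data set are illustrative choices, not part of the original question. In both cases fit() returns the estimator itself, which is also why the question's clf1 and clf2 ended up being the same object:

```python
from sklearn.cluster import KMeans
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_classes=2, n_features=5, random_state=0)

# Unsupervised learning: fit takes a single array X.
km = KMeans(n_clusters=2, n_init=10, random_state=0)
km_fitted = km.fit(X)

# Supervised learning: fit takes the features X and the targets y.
lr = LogisticRegression()
lr_fitted = lr.fit(X, y)

# Both calls return the estimator itself, now carrying learned
# attributes whose names end in an underscore.
print(km_fitted is km, lr_fitted is lr)  # True True
print(km.cluster_centers_.shape, lr.coef_.shape)  # (2, 5) (1, 5)
```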

What is clone() in scikit-learn?

Clone does a deep copy of the model in an estimator without actually copying attached data. It returns a new estimator with the same parameters that has not been fitted on any data. Parameters: estimator: {list, tuple, set} of estimator instances, or a single estimator instance.
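A short sketch of the behaviour described above, assuming a recent scikit-learn; the parameter values C=0.5 and max_iter=500 are arbitrary illustrations. The clone has the same constructor parameters as the original but is a separate, unfitted object, and a list of estimators is cloned element by element:

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X, y = make_classification(n_classes=2, n_features=5, random_state=0)

original = LogisticRegression(C=0.5, max_iter=500).fit(X, y)
duplicate = clone(original)

# Same constructor parameters...
print(duplicate.get_params() == original.get_params())  # True
# ...but a distinct object without the learned attributes.
print(duplicate is original, hasattr(duplicate, "coef_"))  # False False

# A list of estimators comes back as a list of clones.
copies = clone([LogisticRegression(C=1.0), LogisticRegression(C=10.0)])
print([c.C for c in copies])  # [1.0, 10.0]
```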

What is a scikit-learn estimator?

For fitting data, the main API implemented by scikit-learn is that of the estimator. An estimator is any object that learns from data; it may be a classification, regression or clustering algorithm, or a transformer that extracts or filters useful features from raw data.


1 Answer

from sklearn.base import clone

lr1 = LogisticRegression()
lr2 = clone(lr1)
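Applied to the question's two data sets, the pattern might look like the sketch below. The name template and the value C=0.5 are illustrative assumptions; the point is that the parameters are configured once and each data set gets its own independent clone:

```python
from sklearn.base import clone
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

X1, y1 = make_classification(n_classes=2, n_features=5, random_state=1)
X2, y2 = make_classification(n_classes=2, n_features=5, random_state=2)

# Configure the parameters once...
template = LogisticRegression(C=0.5)

# ...then fit a separate clone on each data set.
clf1 = clone(template).fit(X1, y1)
clf2 = clone(template).fit(X2, y2)

print(clf1 is clf2)  # False
print(clf1.get_params() == clf2.get_params())  # True

print("Classifier for data set 1:", clf1.intercept_, clf1.coef_)
print("Classifier for data set 2:", clf2.intercept_, clf2.coef_)
```

The template itself is never fitted here. clone also works on an already-fitted estimator and simply drops its learned state; if you instead want a copy that keeps the fitted attributes, copy.deepcopy(estimator) is the usual choice.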
answered Sep 28 '22 05:09 by Fred Foo
