
GridSearchCV on LogisticRegression in scikit-learn

I am trying to tune a logistic regression model in scikit-learn using a cross-validated grid search over its parameters, but I can't get it to work.

The error says that LogisticRegression does not implement get_params(), but the documentation says it does. How can I tune this model on my ground-truth data?

>>> param_grid = {'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000] }
>>> clf = GridSearchCV(LogisticRegression(penalty='l2'), param_grid)
>>> clf
GridSearchCV(cv=None,
       estimator=LogisticRegression(C=1.0, intercept_scaling=1, dual=False, fit_intercept=True,
          penalty='l2', tol=0.0001),
       fit_params={}, iid=True, loss_func=None, n_jobs=1,
       param_grid={'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]},
       pre_dispatch='2*n_jobs', refit=True, score_func=None, verbose=0)
>>> clf = clf.fit(gt_features, labels)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-x86_64.egg/sklearn/grid_search.py", line 351, in fit
    base_clf = clone(self.estimator)
  File "/Library/Python/2.7/site-packages/scikit_learn-0.14_git-py2.7-macosx-10.8-x86_64.egg/sklearn/base.py", line 42, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object 'LogisticRegression(C=1.0, intercept_scaling=1, dual=False, fit_intercept=True,
          penalty='l2', tol=0.0001)' (type <class 'scikits.learn.linear_model.logistic.LogisticRegression'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_params' methods.
>>> 
genekogan asked Sep 26 '13 02:09



2 Answers

The class name scikits.learn.linear_model.logistic.LogisticRegression in the error message belongs to a very old version of scikit-learn. The top-level package has been named sklearn for at least two or three releases. Note that the traceback goes through sklearn/grid_search.py from the 0.14 install, while the estimator comes from the old scikits.learn package, so you very likely have old versions of scikit-learn installed alongside the new one on your Python path. Uninstall them all, then reinstall 0.14 or later and try again.
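One way to confirm which installation Python actually picks up is to try importing both package names and print where each one lives. This is only a minimal sketch: describe_package is a hypothetical helper name used for illustration, not part of scikit-learn.

```python
import importlib


def describe_package(name):
    """Return (version, file path) for an importable package, or None if it cannot be imported."""
    try:
        module = importlib.import_module(name)
    except ImportError:
        return None
    version = getattr(module, "__version__", "unknown")
    location = getattr(module, "__file__", "unknown")
    return version, location


# "scikits.learn" is the legacy package name seen in the traceback;
# "sklearn" is the current top-level package name.
for package in ("sklearn", "scikits.learn"):
    print(package, "->", describe_package(package))
```

If scikits.learn resolves to a path at all, or sklearn reports an older version than expected, that copy is probably the one getting mixed in with the newer install.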

ogrisel answered Sep 21 '22 16:09



You can also search over penalty along with C, for example:

grid_values = {'penalty': ['l1', 'l2'], 'C': [0.001, 0.01, 0.1, 1, 10, 100, 1000]}
model_lr = GridSearchCV(lr, param_grid=grid_values)
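For reference, a fuller sketch of the same search might look like the following. It assumes a scikit-learn release where GridSearchCV lives in sklearn.model_selection (0.18 or later) rather than the sklearn.grid_search module shown in the question's traceback. The synthetic data from make_classification is only a stand-in for the question's gt_features and labels, and solver="liblinear" is an assumption made here because that solver accepts both 'l1' and 'l2' penalties.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for gt_features and labels (assumption, not the asker's data).
X, y = make_classification(n_samples=200, n_features=10, random_state=0)

# liblinear handles both the 'l1' and 'l2' penalties in the grid below.
lr = LogisticRegression(solver="liblinear")
grid_values = {"penalty": ["l1", "l2"], "C": [0.001, 0.01, 0.1, 1, 10, 100, 1000]}

# 5-fold cross-validation over every (penalty, C) combination.
model_lr = GridSearchCV(lr, param_grid=grid_values, cv=5)
model_lr.fit(X, y)

print("best parameters:", model_lr.best_params_)
print("best CV accuracy:", model_lr.best_score_)
```

After fitting, model_lr.best_estimator_ holds a LogisticRegression refit on all of the data with the best combination found, since refit defaults to True.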

Biswajit Ghoshal answered Sep 21 '22 16:09
