
scikit learn: custom classifier compatible with GridSearchCV

I have implemented my own classifier and now I want to run a grid search over it, but I'm getting the following error: estimator.fit(X_train, y_train, **fit_params) TypeError: fit() takes 2 positional arguments but 3 were given

I followed this tutorial and used this template provided by scikit-learn's official documentation. My class is defined as follows:

from sklearn.base import BaseEstimator, ClassifierMixin

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self

    def predict(self, X):
        # Some code
        return y_pred

    def get_params(self, deep=True):
        return {'lr': self.lr}

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

And I'm trying to grid search through it as follows:

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

EDIT I

This is how I'm calling it: gs.fit(['hello world', 'trying','hello world', 'trying', 'hello world', 'trying', 'hello world', 'trying'], ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])

END EDIT I

The error is produced by the _fit_and_score method in the file python3.5/site-packages/sklearn/model_selection/_validation.py

It is calling estimator.fit(X_train, y_train, **fit_params) with 3 arguments, but my estimator's fit only takes two, so the error makes sense to me, but I don't know how to solve it... I also tried adding some dummy arguments to the fit method, but it didn't work.

EDIT II

Complete error output:

Traceback (most recent call last):
  File "/home/rodrigo/no_version/text_classifier/MyClassifier.py", line 355, in <module>
    ['I', 'Z', 'I', 'Z', 'I', 'Z', 'I', 'Z'])
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_search.py", line 639, in fit
    cv.split(X, y, groups)))
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 779, in __call__
    while self.dispatch_one_batch(iterator):
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 625, in dispatch_one_batch
    self._dispatch(tasks)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 588, in _dispatch
    job = self._backend.apply_async(batch, callback=cb)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 111, in apply_async
    result = ImmediateResult(func)
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/_parallel_backends.py", line 332, in __init__
    self.results = batch()
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in __call__
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/externals/joblib/parallel.py", line 131, in <listcomp>
    return [func(*args, **kwargs) for func, args, kwargs in self.items]
  File "/home/rodrigo/no_version/text_classifier/.env/lib/python3.5/site-packages/sklearn/model_selection/_validation.py", line 458, in _fit_and_score
    estimator.fit(X_train, y_train, **fit_params)
TypeError: fit() takes 2 positional arguments but 3 were given

END EDIT II

SOLVED Thank you all. I had made a stupid mistake: my class contained two different methods with the same name (fit). I had implemented the other one for custom purposes with different parameters; as soon as I renamed my 'custom fit', it worked correctly.
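For anyone hitting the same thing, here is a minimal sketch of how a duplicated method name can produce this exact error. The class name Shadowed and the parameter name data are made up for illustration, and no scikit-learn is needed to reproduce it: in Python, the last definition of a name in a class body silently replaces any earlier one.

```python
import inspect


class Shadowed:
    # The fit that GridSearchCV is expected to call.
    def fit(self, X, y):
        return self

    # A second method with the same name (hypothetical "custom fit").
    # It silently replaces the definition above.
    def fit(self, data):
        return self


# Only the second definition survives on the class.
surviving_signature = str(inspect.signature(Shadowed.fit))

try:
    Shadowed().fit([1, 2], [0, 1])
    error_message = None
except TypeError as exc:
    error_message = str(exc)

print(surviving_signature)
print(error_message)
```

Renaming the second method (for example to custom_fit) leaves the two-argument fit in place, which matches the fix described above.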

Thank you and sorry

Rodrigo Laguna asked Jan 11 '18 16:01




1 Answer

The following code works for me:

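import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV

class MyClassifier(BaseEstimator, ClassifierMixin):
    def __init__(self, lr=0.1):
        self.lr = lr

    def fit(self, X, y):
        # Some code
        return self  # scikit-learn convention: fit returns the estimator

    def predict(self, X):
        # Some code (dummy prediction so the example can be scored)
        return X % 3

params = {
    'lr': [0.1, 0.5, 0.7]
}
gs = GridSearchCV(MyClassifier(), param_grid=params, cv=4)

# 30 samples, three classes with 10 samples each
x = np.arange(30)
y = np.concatenate((np.zeros(10), np.ones(10), np.ones(10) * 2))
gs.fit(x, y)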

The best I can figure is that either you are passing something into the gs.fit method beyond x and y, or your MyClassifier.fit method is missing the self argument.

The fit_params kwargs should only be populated if you pass a keyword argument to the gs.fit method; otherwise it is an empty dictionary ({}), and **fit_params won't throw an argument error. To test this, create an instance of your classifier and pass **{}. For example:

clf = MyClassifier()
clf.fit(x, y, **{})

This does not throw the positional arguments error.
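The same point can be checked without scikit-learn. The sketch below uses a plain stand-in function (not a real estimator) to show that unpacking an empty dictionary adds no arguments, and that a non-empty one would fail with a different message about an unexpected keyword argument rather than the positional-argument error from the question. The some_arg name is only illustrative.

```python
def fit(X, y):
    # Stand-in for an estimator's fit that accepts exactly two arguments.
    return len(X), len(y)


# Unpacking an empty dict passes nothing extra, so the call succeeds.
fit_params = {}
result = fit([1, 2, 3], [0, 1, 0], **fit_params)

# A populated dict fails, but as a *keyword* argument error.
try:
    fit([1, 2, 3], [0, 1, 0], **{"some_arg": 123})
    extra_kwarg_error = None
except TypeError as exc:
    extra_kwarg_error = str(exc)

print(result)
print(extra_kwarg_error)
```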

Therefore, unless something extra is passed to gs.fit, e.g. gs.fit(x, y, some_arg=123), it seems to me that you are missing one of the positional arguments in the definition of MyClassifier.fit. The error message you included supports this hypothesis, as it states that fit() takes 2 positional arguments but 3 were given. If you had defined fit as follows, it would take 3 positional arguments:

def fit(self, X, y): ...
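
As a rough illustration of the missing-self hypothesis, the sketch below (with a made-up class name, NoSelf) defines fit without self. Python still passes the instance as the first argument on a bound call, so a call with two data arguments arrives as three positional arguments.

```python
class NoSelf:
    # 'self' was forgotten: the instance ends up bound to X,
    # so the method only has room for one more argument.
    def fit(X, y):
        return X


try:
    NoSelf().fit([1, 2], [0, 1])
    missing_self_error = None
except TypeError as exc:
    missing_self_error = str(exc)

print(missing_self_error)
```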
Grr answered Nov 09 '22 23:11