Sklearn - define get_params() automatically

I'm trying to define a class that qualifies as an estimator in sklearn, e.g.:

class MyEstimator():
    def __init__(self,verbose=False):
        self.verbose = verbose

    def get_params(self, deep=False):
        return {
            'verbose': self.verbose,
        }

    def set_params(self, **parameters):
        for parameter, value in parameters.items():
            setattr(self, parameter, value)
        return self

    # Also def fit() and other stuff ...

Question

set_params() can be defined without explicitly listing all the parameter names. Is there a way to define get_params() in a similar way?

What I need from sklearn is GridSearchCV, and from what I have tried, it seems that get_params() determines which parameters can be set during cross-validation.

diadochos asked Dec 19 '22 00:12


1 Answer

Just inherit your class from BaseEstimator, which implements get_params() and set_params() for you.

Demo:

In [21]: from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin, ClusterMixin

In [22]: from sklearn.base import BaseEstimator
    ...:
    ...: class MyEstimator(BaseEstimator):
    ...:     def __init__(self,verbose=False):
    ...:         self.verbose = verbose

In [23]: est = MyEstimator(verbose=True)

In [24]: est.get_params()
Out[24]: {'verbose': True}

In [25]: est.set_params(verbose=False)
Out[25]: MyEstimator(verbose=False)

In [26]: est.get_params()
Out[26]: {'verbose': False}
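
For intuition, here is a rough sketch of how parameter names can be discovered automatically by inspecting the __init__ signature, which is broadly the idea BaseEstimator relies on. This is a simplified illustration and not sklearn's actual code: AutoParamsMixin and the alpha parameter are hypothetical names introduced for the example, and nested estimators (deep=True) are not handled.

```python
import inspect


class AutoParamsMixin:
    """Hypothetical helper: derives parameter names from the __init__ signature."""

    @classmethod
    def _param_names(cls):
        # Collect every named __init__ argument except `self`, *args and **kwargs.
        sig = inspect.signature(cls.__init__)
        return sorted(
            name
            for name, p in sig.parameters.items()
            if name != "self" and p.kind not in (p.VAR_POSITIONAL, p.VAR_KEYWORD)
        )

    def get_params(self, deep=False):
        # `deep` is accepted for signature compatibility only; it is ignored here.
        return {name: getattr(self, name) for name in self._param_names()}

    def set_params(self, **parameters):
        valid = self._param_names()
        for name, value in parameters.items():
            if name not in valid:
                raise ValueError(f"Invalid parameter {name!r}")
            setattr(self, name, value)
        return self


class MyEstimator(AutoParamsMixin):
    def __init__(self, verbose=False, alpha=1.0):
        # Each argument is stored under its own name so get_params() can find it.
        self.verbose = verbose
        self.alpha = alpha


est = MyEstimator(verbose=True)
print(est.get_params())
```

The sketch only works if every __init__ argument is stored as an attribute of the same name, which is also the convention sklearn expects of estimators. In practice, inheriting from BaseEstimator is the simpler choice.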

P.S. You may also want to inherit from one of ClassifierMixin, RegressorMixin, or ClusterMixin, depending on what kind of estimator you are going to implement.
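
Since the question is about GridSearchCV, here is a minimal sketch of a BaseEstimator subclass being tuned with it. ThresholdClassifier, its threshold parameter, and the toy data are made up for illustration; the mixin is listed before BaseEstimator, which is the ordering sklearn's own estimators use.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin
from sklearn.model_selection import GridSearchCV


class ThresholdClassifier(ClassifierMixin, BaseEstimator):
    """Hypothetical toy classifier: predicts 1 when the first feature exceeds `threshold`."""

    def __init__(self, threshold=0.0):
        # Store the argument unchanged so the inherited get_params() can read it back.
        self.threshold = threshold

    def fit(self, X, y):
        # Nothing is learned here; only the class labels are recorded.
        self.classes_ = np.unique(y)
        return self

    def predict(self, X):
        X = np.asarray(X)
        return (X[:, 0] > self.threshold).astype(int)


# Toy data: one feature, label is 1 for values above 9.5.
X = np.arange(20, dtype=float).reshape(-1, 1)
y = (X[:, 0] > 9.5).astype(int)

# GridSearchCV clones the estimator and sets `threshold` via get_params()/set_params().
search = GridSearchCV(ThresholdClassifier(), {"threshold": [0.0, 9.5, 15.0]}, cv=2)
search.fit(X, y)
print(search.best_params_)
```

No get_params() or set_params() is written by hand here; both come from BaseEstimator, and score() comes from ClassifierMixin. On this toy data the search should select threshold=9.5, the only candidate that separates the two classes exactly.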

MaxU - stop WAR against UA answered Dec 26 '22 12:12