Inherit from scikit-learn's LassoCV model

I tried to extend scikit-learn's RidgeCV model using inheritance:

from sklearn.linear_model import RidgeCV, LassoCV

class Extended(RidgeCV):
    def __init__(self, *args, **kwargs):
        super(Extended, self).__init__(*args, **kwargs)

    def example(self):
        print('Foo')


x = [[1,0],[2,0],[3,0],[4,0], [30, 1]]
y = [2,4,6,8, 60]
model = Extended(alphas = [float(a)/1000.0 for a in range(1, 10000)])
model.fit(x,y)
print(model.predict([[5,1]]))

It worked perfectly fine, but when I tried to inherit from LassoCV, it yielded the following traceback:

Traceback (most recent call last):
  File "C:/Python27/so.py", line 14, in <module>
    model.fit(x,y)
  File "C:\Python27\lib\site-packages\sklearn\linear_model\coordinate_descent.py", line 1098, in fit
    path_params = self.get_params()
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 214, in get_params
    for key in self._get_param_names():
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 195, in _get_param_names
    % (cls, init_signature))
RuntimeError: scikit-learn estimators should always specify their parameters in the signature of their __init__ (no varargs). <class '__main__.Extended'> with constructor (<self>, *args, **kwargs) doesn't  follow this convention.

Can somebody explain how to fix this?

asked Oct 13 '16 at 15:10 by Markus K





1 Answer

You probably want to make a scikit-learn-compatible model so that you can use it with the rest of scikit-learn's functionality. If so, read this first: http://scikit-learn.org/stable/developers/contributing.html#rolling-your-own-estimator

In short: scikit-learn has many features, such as estimator cloning (the clone() function) and meta-algorithms like GridSearchCV, Pipeline and cross-validation. All of these need to read the parameter values of your estimator and to change them (for example, GridSearchCV has to set a parameter such as alpha in SGDClassifier before each evaluation). To change a parameter, they have to know its name. The get_params method of BaseEstimator (which you are inheriting implicitly) collects those names by introspecting the signature of your class's __init__ method, so it requires every parameter to be listed there explicitly. BaseEstimator is the class that raises this error.
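To see why varargs break this, here is a simplified, stdlib-only sketch (Python 3) of the kind of introspection BaseEstimator performs. It is modelled on the _get_param_names classmethod in scikit-learn's base.py but is not scikit-learn's actual code; the function param_names and the example classes are hypothetical. As far as I can tell from base.py, only *args actually raises; **kwargs is silently skipped, so anything passed through it is simply invisible to get_params.

```python
import inspect


def param_names(cls):
    """Hypothetical, simplified version of BaseEstimator._get_param_names:
    read parameter names from the signature of cls.__init__."""
    init = cls.__init__
    if init is object.__init__:
        # The class defines no constructor anywhere in its MRO: no parameters.
        return []
    sig = inspect.signature(init)
    # Drop 'self' and any **kwargs catch-all (its contents cannot be named).
    params = [p for p in sig.parameters.values()
              if p.name != "self" and p.kind != p.VAR_KEYWORD]
    for p in params:
        if p.kind == p.VAR_POSITIONAL:
            # *args: positional values have no names, so refuse.
            raise RuntimeError("%s with constructor %s uses varargs"
                               % (cls.__name__, sig))
    return sorted(p.name for p in params)


class Explicit:
    def __init__(self, alpha=1.0, cv=None):
        self.alpha = alpha
        self.cv = cv


class KwargsOnly:
    def __init__(self, **kwargs):
        pass


class Varargs:
    def __init__(self, *args, **kwargs):
        pass


class NoInit:
    pass


print(param_names(Explicit))    # ['alpha', 'cv']
print(param_names(KwargsOnly))  # []
print(param_names(NoInit))      # []
try:
    param_names(Varargs)
except RuntimeError as exc:
    print("RuntimeError:", exc)
```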

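So it simply wants you to remove varargs such as

*args, **kwargs

from the __init__ signature. List every parameter of your model explicitly in the __init__ signature, and store each argument unchanged as an attribute with the same name. (Your RidgeCV subclass only seemed to work because RidgeCV.fit did not call get_params in your version; LassoCV.fit does, as the coordinate_descent.py line in the traceback shows.)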
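A minimal sketch of that convention for a custom estimator, assuming a recent scikit-learn and NumPy are installed. ShiftedMeanRegressor and its shift parameter are hypothetical, chosen only to show that once the parameter is explicit in __init__, get_params, set_params and clone all work without extra code.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone


class ShiftedMeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical toy estimator: always predicts mean(y) + shift."""

    def __init__(self, shift=0.0):
        # Only store the argument, unchanged and under the same name;
        # get_params() finds it by introspecting this signature.
        self.shift = shift

    def fit(self, X, y):
        # Learned state gets a trailing underscore, per scikit-learn convention.
        self.mean_ = float(np.mean(y)) + self.shift
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


est = ShiftedMeanRegressor(shift=1.5)
print(est.get_params())                  # {'shift': 1.5}

# clone() rebuilds the estimator from get_params(); set_params() changes it,
# which is what GridSearchCV does between evaluations.
est2 = clone(est).set_params(shift=-1.0)
print(est2.get_params())                 # {'shift': -1.0}

print(est.fit([[0], [1]], [2, 4]).predict([[5]]))  # [4.5]
```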

Here is an example, the __init__ method of SGDClassifier (which inherits from BaseSGDClassifier). Every parameter is listed explicitly and passed on to the parent constructor by name:

def __init__(self, loss="hinge", penalty='l2', alpha=0.0001, l1_ratio=0.15,
             fit_intercept=True, n_iter=5, shuffle=True, verbose=0,
             epsilon=DEFAULT_EPSILON, n_jobs=1, random_state=None,
             learning_rate="optimal", eta0=0.0, power_t=0.5,
             class_weight=None, warm_start=False, average=False):
    super(SGDClassifier, self).__init__(
        loss=loss, penalty=penalty, alpha=alpha, l1_ratio=l1_ratio,
        fit_intercept=fit_intercept, n_iter=n_iter, shuffle=shuffle,
        verbose=verbose, epsilon=epsilon, n_jobs=n_jobs,
        random_state=random_state, learning_rate=learning_rate, eta0=eta0,
        power_t=power_t, class_weight=class_weight, warm_start=warm_start, average=average)
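For the LassoCV case in the question, a minimal sketch assuming a recent scikit-learn: since Extended adds no new constructor parameters, one option is not to define __init__ at all. get_params() then introspects the inherited LassoCV.__init__, whose parameters are all explicit. If you do need extra parameters, you would have to mirror LassoCV's full signature as in the SGDClassifier excerpt above, and that signature differs between scikit-learn versions. The cv=2 value is an assumption, chosen only because the toy dataset has five samples.

```python
from sklearn.linear_model import LassoCV


class Extended(LassoCV):
    # No __init__ override: get_params() introspects the inherited
    # LassoCV.__init__, which lists every parameter explicitly.

    def example(self):
        return "Foo"


x = [[1, 0], [2, 0], [3, 0], [4, 0], [30, 1]]
y = [2, 4, 6, 8, 60]
model = Extended(cv=2)   # cv=2: assumed small fold count for 5 samples
model.fit(x, y)          # no RuntimeError: get_params() succeeds
print(model.example())
print(model.predict([[5, 1]]))
```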
answered Oct 21 '22 at 13:10 by Ibraim Ganiev
