What's the full specification for implementing a custom scikit-learn estimator?

I'm rolling my own predictor and want to use it the way I would use any of the scikit-learn estimators (e.g. RandomForestRegressor). I have a class containing fit and predict methods that seem to work fine. However, when I try to use some of the scikit-learn utilities, such as cross-validation, I get errors like:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_score
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __call__
    for function, args, kwargs in iterable:
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'instance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_params' methods.

I see that it wants me to implement some methods (presumably get_params, and perhaps set_params and score), but I'm not sure what the specification for these methods is. Is there documentation on this topic? Thanks.

asked May 26 '14 09:05 by rhombidodecahedron


1 Answer

Full instructions are available in the scikit-learn docs, and the principles behind the API are set out in a paper by yours truly et al. In short, besides fit, an estimator needs get_params and set_params, which respectively return (as a dict) and set (from keyword arguments) the estimator's hyperparameters, i.e. the parameters of the learning algorithm itself, as opposed to the parameters it learns from the data. These hyperparameters should match the __init__ parameters: each keyword argument of __init__ should be stored, unchanged, as an attribute of the same name.
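A minimal sketch of that contract in plain Python, with no scikit-learn import. The class MeanRegressor, its shrinkage hyperparameter and the helper clone_like are hypothetical names chosen for illustration; clone_like is a simplified stand-in for what sklearn.base.clone relies on (building a fresh, unfitted copy from get_params()), not its actual implementation.

```python
import copy


class MeanRegressor:
    """Hypothetical estimator: predicts a shrunken mean of the training targets.

    Hyperparameters are the __init__ keyword arguments, stored unchanged
    as attributes of the same name.
    """

    def __init__(self, shrinkage=0.0):
        self.shrinkage = shrinkage

    def get_params(self, deep=True):
        # Return the hyperparameters (not the learned state) as a dict.
        # `deep` is accepted for API compatibility; there are no nested estimators here.
        return {"shrinkage": self.shrinkage}

    def set_params(self, **params):
        # Set hyperparameters from keyword arguments; return self to allow chaining.
        valid = self.get_params()
        for key, value in params.items():
            if key not in valid:
                raise ValueError("Invalid parameter %r for %s" % (key, type(self).__name__))
            setattr(self, key, value)
        return self

    def fit(self, X, y):
        # Learned state gets a trailing underscore to keep it apart from hyperparameters.
        self.mean_ = (1.0 - self.shrinkage) * sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def clone_like(estimator):
    # Hypothetical helper: a fresh, unfitted instance rebuilt from get_params().
    params = {k: copy.deepcopy(v) for k, v in estimator.get_params().items()}
    return type(estimator)(**params)
```

Because clone_like only goes through get_params() and the constructor, learned attributes such as mean_ are never copied, which is exactly why a cloning utility needs get_params to exist.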

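Both methods can be obtained by inheriting from the classes in sklearn.base (BaseEstimator provides get_params and set_params; mixins such as RegressorMixin and ClassifierMixin add a default score), but you can provide them yourself if you don't want your code to depend on scikit-learn.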
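A sketch of the inheritance route, assuming a recent scikit-learn release: cross_val_score is imported from sklearn.model_selection, since the sklearn.cross_validation module seen in the traceback was deprecated and later removed. MeanRegressor and shrinkage are again hypothetical; only fit and predict are written by hand, while get_params/set_params come from BaseEstimator and score (R^2) comes from RegressorMixin.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.model_selection import cross_val_score


class MeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical estimator: predicts a shrunken mean of the training targets."""

    def __init__(self, shrinkage=0.0):
        # Store each hyperparameter unchanged under its own name;
        # BaseEstimator reads the __init__ signature to implement
        # get_params() and set_params().
        self.shrinkage = shrinkage

    def fit(self, X, y):
        y = np.asarray(y, dtype=float)
        # Fitted state uses a trailing underscore.
        self.mean_ = (1.0 - self.shrinkage) * y.mean()
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


if __name__ == "__main__":
    X = np.arange(12, dtype=float).reshape(-1, 1)
    y = np.arange(12, dtype=float)

    est = MeanRegressor(shrinkage=0.1)
    print(est.get_params())         # {'shrinkage': 0.1}
    print(clone(est).get_params())  # same hyperparameters, unfitted copy
    # cross_val_score clones the estimator per fold and, by default,
    # scores it with the score() method supplied by RegressorMixin.
    print(cross_val_score(est, X, y, cv=3))
```

The mixin is listed before BaseEstimator, following the ordering current scikit-learn guidance recommends; with this class, the clone call that failed in the question's traceback succeeds.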

Note that input validation should be done in fit, not in the constructor: set_params does not go through __init__, so constructor-time checks can be bypassed by setting invalid parameters later, and fit would then fail in unexpected ways.
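To illustrate the point about validation, here is a sketch assuming scikit-learn is installed; ShrunkMeanRegressor and its [0, 1] range for shrinkage are hypothetical. set_params accepts an out-of-range value without complaint, and the check in fit is what catches it.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin


class ShrunkMeanRegressor(RegressorMixin, BaseEstimator):
    """Hypothetical estimator showing where hyperparameter validation belongs."""

    def __init__(self, shrinkage=0.0):
        # No checks here: set_params() bypasses __init__, so a check in the
        # constructor would not catch values assigned later.
        self.shrinkage = shrinkage

    def fit(self, X, y):
        # Validating at fit time covers values from both __init__ and set_params().
        if not 0.0 <= self.shrinkage <= 1.0:
            raise ValueError("shrinkage must be in [0, 1], got %r" % (self.shrinkage,))
        y = np.asarray(y, dtype=float)
        self.mean_ = (1.0 - self.shrinkage) * y.mean()
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_)


if __name__ == "__main__":
    est = ShrunkMeanRegressor().set_params(shrinkage=2.0)  # accepted: no check yet
    try:
        est.fit([[0.0], [1.0]], [1.0, 3.0])
    except ValueError as exc:
        print(exc)  # shrinkage must be in [0, 1], got 2.0
```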

answered Oct 13 '22 22:10 by Fred Foo