I'm rolling my own predictor and want to use it the way I would use any of the scikit-learn routines (e.g. RandomForestRegressor). I have a class containing fit and predict methods that seem to work fine. However, when I try to use some of the scikit-learn utilities, such as cross-validation, I get errors like:
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in cross_val_score
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\externals\joblib\parallel.py", line 516, in __call__
    for function, args, kwargs in iterable:
  File "C:\Python27\lib\site-packages\sklearn\cross_validation.py", line 1152, in <genexpr>
    for train, test in cv)
  File "C:\Python27\lib\site-packages\sklearn\base.py", line 43, in clone
    % (repr(estimator), type(estimator)))
TypeError: Cannot clone object '<__main__.Custom instance at 0x033A6990>' (type <type 'instance'>): it does not seem to be a scikit-learn estimator a it does not implement a 'get_params' methods.
I see that it wants me to implement some methods (presumably get_params, and maybe also set_params and score), but I'm not sure what the right specification for these methods is. Is there some information available on this topic? Thanks.
Estimator objects. Fitting data: the main API implemented by scikit-learn is that of the estimator. An estimator is any object that learns from data; it may be a classification, regression or clustering algorithm, or a transformer that extracts/filters useful features from raw data.
Full instructions are available in the scikit-learn docs, and the principles behind the API are set out in this paper by yours truly et al. In short, besides fit, what you need for an estimator are get_params and set_params, which respectively return (as a dict) and set (from kwargs) the hyperparameters of the estimator, i.e. the parameters of the learning algorithm itself (as opposed to the parameters it learns from the data). These parameters should match the __init__ parameters.
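As a rough illustration of that contract, here is a minimal sketch of a hand-written estimator with no scikit-learn dependency. The class name Custom is taken from the traceback in the question; the scale hyperparameter and the "predict the scaled training mean" logic are hypothetical, chosen only to keep the example short. The deep argument of get_params and the "return self" convention in set_params and fit mirror what scikit-learn's own estimators do.

```python
class Custom(object):
    """Toy predictor: always predicts scale * mean(y_train)."""

    def __init__(self, scale=1.0):
        # Store each constructor argument unchanged, under the same name.
        self.scale = scale

    def get_params(self, deep=True):
        # Return the hyperparameters as a dict keyed by __init__ argument name.
        # `deep` matters only for nested estimators, so it is ignored here.
        return {"scale": self.scale}

    def set_params(self, **params):
        # Set hyperparameters from keyword arguments; reject unknown names.
        valid = self.get_params()
        for name, value in params.items():
            if name not in valid:
                raise ValueError("Invalid parameter %r" % name)
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        # Learned state (as opposed to hyperparameters) lives in mean_.
        self.mean_ = sum(y) / float(len(y))
        return self

    def predict(self, X):
        return [self.scale * self.mean_ for _ in X]


est = Custom(scale=2.0)
est.set_params(scale=0.5)
# Rebuilding from get_params() is essentially what cloning relies on.
rebuilt = Custom(**est.get_params())
preds = est.fit([[0], [1]], [2.0, 4.0]).predict([[5]])
```

Because get_params keys match the __init__ arguments, an unfitted copy can be rebuilt with Custom(**est.get_params()), which is why the two have to stay in sync.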
Both methods can be obtained by inheriting from the classes in sklearn.base, but you can provide them yourself if you don't want your code to be dependent on scikit-learn.
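The inheritance route might look like the sketch below. It assumes a recent scikit-learn, where cross_val_score lives in sklearn.model_selection rather than the sklearn.cross_validation module shown in the traceback. MeanOffsetRegressor and its offset hyperparameter are hypothetical names for illustration. BaseEstimator supplies get_params and set_params by inspecting the __init__ signature, and RegressorMixin supplies a default score (R^2).

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone
from sklearn.model_selection import cross_val_score


class MeanOffsetRegressor(RegressorMixin, BaseEstimator):
    """Toy regressor: predicts the training mean plus a fixed offset."""

    def __init__(self, offset=0.0):
        # Only store the arguments; get_params reads them back by name.
        self.offset = offset

    def fit(self, X, y):
        # Learned attributes get a trailing underscore by convention.
        self.mean_ = float(np.mean(y))
        return self

    def predict(self, X):
        return np.full(len(X), self.mean_ + self.offset)


X = np.arange(20, dtype=float).reshape(-1, 1)
y = 2.0 * X.ravel()

est = MeanOffsetRegressor(offset=1.0)
params = est.get_params()                  # inherited from BaseEstimator
copy = clone(est)                          # no longer raises TypeError
scores = cross_val_score(est, X, y, cv=4)  # one R^2 score per fold
```

Here the mixin is listed before BaseEstimator in the class definition, which is the order the scikit-learn documentation recommends.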
Note that input validation should be done in fit, not the constructor, because otherwise you can still set invalid parameters in set_params and have fit fail in unexpected ways.
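To make the point about validation concrete, here is a small sketch in plain Python. ValidatedCustom and its scale hyperparameter are hypothetical, and the positivity check stands in for whatever constraint a real predictor would have. Since the check lives in fit, a bad value introduced through set_params is still caught with a clear message.

```python
class ValidatedCustom(object):
    """Toy predictor that validates its hyperparameter in fit."""

    def __init__(self, scale=1.0):
        # No validation here: the constructor only records the value.
        self.scale = scale

    def get_params(self, deep=True):
        return {"scale": self.scale}

    def set_params(self, **params):
        for name, value in params.items():
            if name not in self.get_params():
                raise ValueError("Invalid parameter %r" % name)
            setattr(self, name, value)
        return self

    def fit(self, X, y):
        # Validate at the point of use, so values set after
        # construction go through the same check.
        if self.scale <= 0:
            raise ValueError("scale must be positive, got %r" % self.scale)
        self.mean_ = sum(y) / float(len(y))
        return self


est = ValidatedCustom(scale=1.0)
est.set_params(scale=-1.0)  # accepted silently; nothing is checked yet

try:
    est.fit([[0], [1]], [2.0, 4.0])
    message = None
except ValueError as exc:
    message = str(exc)
```

Had the check been placed in __init__ instead, the set_params call would have bypassed it entirely.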