
How to implement a meta-estimator with the scikit-learn API?

I would like to implement a simple wrapper / meta-estimator which is compatible with all of scikit-learn. It is hard to find a full description of what exactly I need.

The goal is to have a regressor which also learns a threshold to become a classifier. So I came up with:

import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

class Thresholder(BaseEstimator, ClassifierMixin):
    def __init__(self, regressor):
        self.regressor = regressor
        # threshold_ does not get initialized in __init__ ??

    def fit(self, X, y, optimal_threshold):
        self.regressor = clone(self.regressor)    # is this required by sklearn??
        self.regressor.fit(X, y)

        y_raw = self.regressor.predict(X)
        self.threshold_ = optimal_threshold(y_raw)
        return self

    def predict(self, X):
        y_raw = self.regressor.predict(X)

        y = np.digitize(y_raw, [self.threshold_])

        return y

Does this implement the full API I need?

My main question is where to put the threshold. I want it to be learned only once and then reused in subsequent .fit calls on new data without being readjusted. But with the current version it is retuned on every .fit call, which I do not want.

On the other hand, if I make it a fixed parameter self.threshold and pass it to __init__, then I'm not supposed to change it based on the data, right?

How can I make a threshold parameter which can be tuned in one call of .fit and be fixed for subsequent .fit calls?

Gerenuk asked Nov 11 '19 at 15:11



1 Answer

I actually wrote a blog post about this the other day. I assume you are trying to build something similar to TransformedTargetRegressor, so I would suggest looking at its source code and building something along the same lines.
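
A minimal sketch of the conventions TransformedTargetRegressor follows, assuming a recent scikit-learn: the regressor passed to __init__ is cloned into regressor_ inside fit, the learned threshold is stored in the fitted attribute threshold_, and fit returns self. The optional threshold constructor parameter is a hypothetical addition (not in the question's code) that lets a caller freeze the value, and median_threshold is only a placeholder for your own optimal_threshold.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LinearRegression
from sklearn.utils.validation import check_is_fitted


def median_threshold(y_raw):
    # Placeholder rule (assumption): split the raw predictions at their median.
    return float(np.median(y_raw))


class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor, threshold=None):
        # Store constructor arguments unchanged; get_params/clone rely on this.
        self.regressor = regressor
        self.threshold = threshold

    def fit(self, X, y):
        # Fit a clone so the regressor passed to __init__ is never mutated.
        self.regressor_ = clone(self.regressor)
        self.regressor_.fit(X, y)
        if self.threshold is None:
            # No fixed threshold given: learn one from the training predictions.
            self.threshold_ = median_threshold(self.regressor_.predict(X))
        else:
            # A fixed threshold was given: use it without re-tuning.
            self.threshold_ = self.threshold
        return self

    def predict(self, X):
        check_is_fitted(self)
        # np.digitize returns 0 below the threshold and 1 at or above it.
        return np.digitize(self.regressor_.predict(X), [self.threshold_])


# Usage: learn the threshold once, then freeze it for later fits.
X = np.arange(10, dtype=float).reshape(-1, 1)
y = np.arange(10, dtype=float)
model = Thresholder(LinearRegression()).fit(X, y)
learned = model.threshold_                          # about 4.5 for this toy data
model.set_params(threshold=learned).fit(X, 2 * y)   # regressor refit, threshold kept
```

Because threshold is an ordinary constructor parameter in this sketch, freezing the learned value is just set_params(threshold=model.threshold_), and the frozen value also survives clone(), for example inside GridSearchCV.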

Your current implementation seems about right. As far as this concern goes:

How can I make a threshold parameter which can be tuned in one call of .fit and be fixed for subsequent .fit calls?

I would advise against that, because scikit-learn's API is built around the fit method re-learning every tunable aspect of the model from scratch. There are two routes you can take here: either add a **kwarg to fit that explicitly protects the threshold from being updated, or go with what @rotem-tal suggested. If you choose the latter, it might look something like this:

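import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone

def optimal_threshold(y_raw: np.ndarray) -> np.ndarray:
    return np.array([0.1, 0.5, 1])  # some implementation here

class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor, threshold=None):
        self.regressor = regressor
        # threshold is a constructor parameter so get_params/clone know about it
        self.threshold = threshold

    def fit(self, X, y, optimal_threshold=optimal_threshold):
        # fit a clone (as TransformedTargetRegressor does) so the
        # regressor passed to __init__ is left untouched
        self.regressor_ = clone(self.regressor)
        self.regressor_.fit(X, y)

        y_raw = self.regressor_.predict(X)
        # learn the threshold only on the first fit; later fits reuse it
        if self.threshold is None:
            self.threshold = optimal_threshold(y_raw)
        return self

    def predict(self, X):
        y_raw = self.regressor_.predict(X)

        # accept either a single threshold or an array of bin edges
        y = np.digitize(y_raw, np.atleast_1d(self.threshold))

        return y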
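
For the first route mentioned above, a keyword argument to fit that protects the threshold, a sketch might look like the following. The name update_threshold is hypothetical, and the median rule is only a placeholder for your own optimal_threshold.

```python
import numpy as np
from sklearn.base import BaseEstimator, ClassifierMixin, clone
from sklearn.linear_model import LinearRegression


def optimal_threshold(y_raw):
    # Placeholder rule (assumption): median of the raw predictions.
    return float(np.median(y_raw))


class Thresholder(ClassifierMixin, BaseEstimator):
    def __init__(self, regressor):
        self.regressor = regressor

    def fit(self, X, y, update_threshold=True):
        # update_threshold is a hypothetical keyword argument: when False, the
        # threshold_ learned by an earlier fit is kept and only the regressor is refit.
        self.regressor_ = clone(self.regressor)
        self.regressor_.fit(X, y)
        if update_threshold or not hasattr(self, "threshold_"):
            self.threshold_ = optimal_threshold(self.regressor_.predict(X))
        return self

    def predict(self, X):
        # 0 below the threshold, 1 at or above it.
        return np.digitize(self.regressor_.predict(X), [self.threshold_])


X = np.arange(10, dtype=float).reshape(-1, 1)
y = np.arange(10, dtype=float)
model = Thresholder(LinearRegression()).fit(X, y)   # learns about 4.5
model.fit(X, 10 * y, update_threshold=False)        # keeps about 4.5
```

Keep in mind that the protected threshold lives in a fitted attribute here, so anything that calls clone() (cross-validation, GridSearchCV) starts from an unfitted copy and will learn it again. Extra fit arguments also have to be passed explicitly when the estimator sits inside a Pipeline.
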
Adithya answered Sep 21 '22 at 15:09
