
Making a ML model scikit-learn compatible

I want to make this ML model scikit-learn compatible: https://github.com/manifoldai/merf

To do that, I followed the instructions here: https://danielhnyk.cz/creating-your-own-estimator-scikit-learn/. I imported BaseEstimator and RegressorMixin from sklearn.base and inherited from them like so: class MERF(BaseEstimator, RegressorMixin):

However, when I check for scikit-learn compatibility:

from sklearn.utils.estimator_checks import check_estimator

import merf
check_estimator(merf)

I get this error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 500, in check_estimator
    for estimator, check in checks_generator:
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in _generate_instance_checks
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 340, in <genexpr>
    yield from ((estimator, partial(check, name))
  File "C:\Users\hap\anaconda3\envs\a1\lib\site-packages\sklearn\utils\estimator_checks.py", line 232, in _yield_all_checks
    tags = estimator._get_tags()
AttributeError: module 'merf' has no attribute '_get_tags'

How do I make this model scikit-learn compatible?

user308827 asked May 09 '21 00:05





1 Answer

From the docs, check_estimator is used to "Check if estimator adheres to scikit-learn conventions."

"This estimator will run an extensive test-suite for input validation, shapes, etc, making sure that the estimator complies with scikit-learn conventions as detailed in Rolling your own estimator. Additional tests for classifiers, regressors, clustering or transformers will be run if the Estimator class inherits from the corresponding mixin from sklearn.base."

So check_estimator is more than a compatibility check: it also verifies that the estimator follows all of scikit-learn's conventions.

You can read up on rolling your own estimator to make sure you follow the convention.

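Then you need to pass an instance of your estimator class to check_estimator, as in check_estimator(MERF()). Your call passed the merf module itself, which is why the traceback ends with "module 'merf' has no attribute '_get_tags'". To make the class follow all the conventions, you have to work through the errors check_estimator raises and fix them one by one.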
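To see why the module fails where an instance works, here is a minimal sketch. StandInRegressor and the merf_stand_in module are hypothetical stand-ins built only for this illustration; they are not the real MERF class or package.

```python
import types

from sklearn.base import BaseEstimator, RegressorMixin


class StandInRegressor(BaseEstimator, RegressorMixin):
    """Hypothetical stand-in for MERF; not the real class."""

    def __init__(self, max_iterations=20):
        self.max_iterations = max_iterations


# A module object that plays the role of the `merf` module from the question.
fake_module = types.ModuleType("merf_stand_in")
fake_module.StandInRegressor = StandInRegressor

# The estimator API (get_params, set_params, ...) lives on instances of the
# class, not on the module that contains the class.
module_has_api = hasattr(fake_module, "get_params")

instance = fake_module.StandInRegressor()
instance_has_api = hasattr(instance, "get_params")

print(module_has_api, instance_has_api)
```

With the real package the call would look something like check_estimator(merf.MERF()), assuming the class is exposed as merf.MERF in your installed version.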

For example, one such check is that the __init__ method sets only those attributes that it accepts as parameters.

The MERF class violates that:

    def __init__(
        self,
        fixed_effects_model=RandomForestRegressor(n_estimators=300, n_jobs=-1),
        gll_early_stop_threshold=None,
        max_iterations=20,
    ):
        self.gll_early_stop_threshold = gll_early_stop_threshold
        self.max_iterations = max_iterations

        self.cluster_counts = None
        # Note fixed_effects_model must already be instantiated when passed in.
        self.fe_model = fixed_effects_model
        self.trained_fe_model = None
        self.trained_b = None

        self.b_hat_history = []
        self.sigma2_hat_history = []
        self.D_hat_history = []
        self.gll_history = []
        self.val_loss_history = []

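It sets attributes such as self.b_hat_history even though they are not parameters. It also stores the fixed_effects_model argument under a different name, self.fe_model, whereas the convention is to store each parameter under its own name.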
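As a rough sketch of what the convention asks for, the toy estimator below stores each constructor argument under its own name in __init__ and creates everything learned during training inside fit, with a trailing underscore. MeanRegressor, mean_ and gll_history_ are hypothetical names chosen for illustration; this is not a rewrite of MERF, and passing get_params and clone does not by itself mean check_estimator will pass.

```python
import numpy as np
from sklearn.base import BaseEstimator, RegressorMixin, clone


class MeanRegressor(BaseEstimator, RegressorMixin):
    """Hypothetical toy regressor (not MERF): always predicts the training mean."""

    def __init__(self, max_iterations=20, gll_early_stop_threshold=None):
        # Only store the arguments, unchanged, under the same names.
        self.max_iterations = max_iterations
        self.gll_early_stop_threshold = gll_early_stop_threshold

    def fit(self, X, y):
        X = np.asarray(X, dtype=float)
        y = np.asarray(y, dtype=float)
        # State learned from data is created here, with a trailing underscore.
        self.n_features_in_ = X.shape[1]
        self.mean_ = float(y.mean())
        self.gll_history_ = []  # per-fit history would be filled during training
        return self

    def predict(self, X):
        X = np.asarray(X, dtype=float)
        return np.full(X.shape[0], self.mean_)


model = MeanRegressor(max_iterations=5)
params = model.get_params()   # works because attribute names match argument names
fresh_copy = clone(model)     # unfitted copy with the same parameters
model.fit([[0.0], [1.0], [2.0]], [1.0, 2.0, 3.0])
preds = model.predict([[10.0]])
```

Because get_params looks up each constructor argument by name on the instance, an __init__ that saves an argument under a different attribute name breaks get_params, and therefore clone and tools such as GridSearchCV that rely on it.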

There are lots of other checks like this.

My personal advice is not to chase all of these checks unless you need to. Just inherit from the mixin and base classes, implement the needed methods, and use the model.

Vikash Balasubramanian answered Oct 09 '22 22:10
