
pickle/joblib AttributeError: module '__main__' has no attribute 'thing' in pytest

I have built a custom sklearn pipeline, as follows:

pipeline = make_pipeline(
    SelectColumnsTransfomer(features_to_use),
    ToDummiesTransformer('feature_0', prefix='feat_0', drop_first=True, dtype=bool), # Dummify customer_type
    ToDummiesTransformer('feature_1', prefix='feat_1'), # Dummify the feature
    ToDummiesTransformer('feature_2', prefix='feat_2'), # Dummify 
    ToDummiesTransformer('feature_3', prefix='feat_3'), # Dummify
)
pipeline.fit(df)

The classes SelectColumnsTransfomer and ToDummiesTransformer are custom sklearn steps that inherit from BaseEstimator and TransformerMixin. To serialise this object I use

from sklearn.externals import joblib
joblib.dump(pipeline, 'data_pipeline.joblib')

but when I deserialise it with

pipeline = joblib.load('data_pipeline.joblib') 

I get AttributeError: module '__main__' has no attribute 'SelectColumnsTransfomer'.

I have read other similar questions and followed the instructions in this blogpost here, but couldn't solve the issue. I am copy-pasting the classes and importing them in the code. A simplified version of this exercise works; the problem only occurs when I run tests with pytest, which doesn't seem to see my custom classes.

The error output also contains self = <sklearn.externals.joblib.numpy_pickle.NumpyUnpickler object at 0x7f821508a588>, module = '__main__', name = 'SelectColumnsTransfomer', which suggests that the NumpyUnpickler can't find SelectColumnsTransfomer in __main__ even though the test imports it.

My test code

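import pytest
from sklearn.externals import joblib  # same joblib import used for dumping
from app.pipeline import * # the pipeline objects 
                          # SelectColumnsTransfomer and ToDummiesTransformer 
                          # are here!


@pytest.fixture(scope="module")
def clf():
    pipeline = joblib.load("persistence/data_pipeline.joblib")
    return pipeline  # return the loaded pipeline, not the fixture function itself

def test_fake(clf):
    assert True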

DarioB asked Nov 17 '22 22:11


1 Answer

I had the same error message when I was trying to save an instance of a PyTorch class like this:

import torch.nn as nn

class custom(nn.Module):
    def __init__(self):
        super(custom, self).__init__()
        print("Class loaded")

model = custom()

I then used joblib to dump the model:

from joblib import dump
dump(model, 'some_filepath.joblib')

The catch was that I ran the code above in a Kaggle kernel, then downloaded the dumped file and tried to load it locally with this script:

from joblib import load
model = load('some_filepath.joblib')  # load() takes the file path, not the model
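
A rough sketch of why a load like this can fail, using the standard pickle module (joblib's NumpyUnpickler, seen in the question's traceback, builds on pickle's Unpickler). The dump records where the class lives (module name plus class name), not the class body, so the loader has to find the class at that same location. The module name training_script and the class Thing are hypothetical stand-ins for the dumping script and the custom class; this simulates the situation in one process and is not a reproduction of the Kaggle setup.

```python
import pickle
import sys
import types

# Hypothetical module standing in for the script (or notebook) that
# defined the class and called dump().
origin = types.ModuleType("training_script")
sys.modules["training_script"] = origin


def make_thing_class():
    class Thing:
        def __init__(self, value):
            self.value = value

    # Make the class look as if it were defined at the top of that module.
    Thing.__module__ = "training_script"
    Thing.__qualname__ = "Thing"
    return Thing


origin.Thing = make_thing_class()
payload = pickle.dumps(origin.Thing(3))

# The payload names the module and the class; the class body is not in it.
records_module = b"training_script" in payload
records_class = b"Thing" in payload

# Simulate loading where the module exists but does not define the class.
thing_cls = origin.Thing
del origin.Thing
try:
    pickle.loads(payload)
    load_error = None
except AttributeError as exc:
    load_error = exc

# Once the class is reachable at the recorded location, loading works.
origin.Thing = thing_cls
restored = pickle.loads(payload)
```

When the dumping code runs as a script or in a notebook, the recorded module is __main__, which is a different module in the process that loads the file.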

I fixed it by running all of these snippets on the same machine: defining the class, dumping the model and loading it locally, instead of defining and dumping on Kaggle and then loading locally. Most likely this works because the dumped file only refers to the class by name, so the loading code needs the class definition available as well. I'm adding this here because the comments on @DarioB's answer confused me: they refer to a 'function', which didn't apply in my simpler case.
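
If re-dumping is not an option, a workaround often suggested for this error is to bind the class onto __main__ before loading, so that the lookup recorded in the file succeeds. The sketch below uses the standard pickle module and a hypothetical Thing class, and it assumes the file was written by a script in which the class lived in __main__. The helper load_with_main_alias is an illustrative name, not part of pickle or joblib. The same idea should carry over to joblib.load, since it resolves classes the same way, but treat it as a stopgap and not as a proper fix.

```python
import pickle
import sys

main_module = sys.modules["__main__"]


def make_thing_class():
    # Hypothetical class; in a real project it would be imported from a package.
    class Thing:
        def __init__(self, value):
            self.value = value

    # Pretend it was defined at the top level of the dumping script.
    Thing.__module__ = "__main__"
    Thing.__qualname__ = "Thing"
    return Thing


def load_with_main_alias(data, *classes):
    """Bind each class onto __main__ under its own name, then unpickle."""
    for cls in classes:
        setattr(main_module, cls.__name__, cls)
    return pickle.loads(data)


thing_cls = make_thing_class()

# Simulate the dump: Thing is reachable as __main__.Thing while pickling.
main_module.Thing = thing_cls
payload = pickle.dumps(thing_cls(7))
del main_module.Thing  # the loading process has no __main__.Thing

# The alias makes the recorded location valid again, so loading succeeds.
restored = load_with_main_alias(payload, thing_cls)
```

The more durable option is to define the class in an importable module and dump from code that imports it from there, so the file records that module instead of __main__.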

Madhav Malhotra answered Dec 07 '22 23:12