
Saving an sklearn `FunctionTransformer` with the function it wraps

I am using sklearn's Pipeline and FunctionTransformer with a custom function.

import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline

This is my code:

def f(x):
    return x*2
pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib") # Causes an exception

And I get this error:

AttributeError: module '__main__' has no attribute 'f'

How can this be resolved?

Note that the same issue also occurs with plain pickle.

Asked by Uri Goren on Jan 02 '19 at 20:01




1 Answer

I was able to hack together a solution using the marshal module (in addition to pickle) by overriding the magic methods __getstate__ and __setstate__ that pickle uses.

import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

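class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, func):
        # The attribute name matches the argument name so that
        # BaseEstimator.get_params() and clone() can find it.
        self.func = func
    def __call__(self, X):
        return self.func(X)
    def __getstate__(self):
        # Work on a copy so that pickling does not strip func from the live object.
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state
    def __setstate__(self, state):
        # Rebuild the function from its marshaled code object and its name.
        state = dict(state)
        code = marshal.loads(state.pop("func_code"))
        state["func"] = FunctionType(code, globals(), state.pop("func_name"))
        self.__dict__.update(state)
    def fit(self, X, y=None):
        return self
    def transform(self, X):
        return self.func(X)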

Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:

import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2
pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")

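This works by leaving the function object out of the pickled state and storing its marshaled code object and its name instead. On load, __setstate__ rebuilds the function from those two pieces. Only the code object is stored, so default argument values and closures are not preserved, and the rebuilt function looks up global names in the module that defines MyFunctionTransformer.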
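The marshaling step can be tried on its own, without sklearn. pickle records a plain function only by its module and name, which is why loading fails once `f` is gone from `__main__`. The sketch below stores the code object's bytes instead. The helper names `dump_function` and `load_function` are made up for this illustration and are not part of any library.

```python
import marshal
from types import FunctionType


def dump_function(func):
    """Return the function's name and its marshaled code object."""
    return func.__name__, marshal.dumps(func.__code__)


def load_function(name, code_bytes):
    """Rebuild a function from a name and marshaled code bytes.

    The rebuilt function uses this module's globals and has no
    default argument values or closure.
    """
    code = marshal.loads(code_bytes)
    return FunctionType(code, globals(), name)


def double(x):
    return x * 2


# The (name, bytes) pair contains only a str and a bytes object,
# so pickle or joblib can store it without looking up the function.
name, code_bytes = dump_function(double)
rebuilt = load_function(name, code_bytes)
print(rebuilt(21))
```

The marshal format is tied to the Python version, so bytes written by one interpreter version may not load in another. This approach is probably best kept to cases where the dump and the load run on the same Python version.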

dill also looks like a good alternative to marshaling.

Answered by Uri Goren on Oct 14 '22 at 01:10