I am using sklearn's Pipeline and FunctionTransformer with a custom function:
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.preprocessing import FunctionTransformer
from sklearn.pipeline import Pipeline
This is my code:
def f(x):
    return x*2

pipe = Pipeline([("times_2", FunctionTransformer(f))])
joblib.dump(pipe, "pipe.joblib")
del pipe
del f
pipe = joblib.load("pipe.joblib")  # Causes an exception
And I get this error:
AttributeError: module '__main__' has no attribute 'f'
How can this be resolved?
Note that this issue also occurs with pickle.
A FunctionTransformer forwards its X (and optionally y) arguments to a user-defined function or function object and returns the result of this function. This is useful for stateless transformations such as taking the log of frequencies, doing custom scaling, etc.
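A minimal sketch of that contract, assuming NumPy and scikit-learn are installed: the transformer below simply forwards X to np.log1p. Because np.log1p can be found again by name in NumPy, the fitted transformer also survives a pickle round trip, which is exactly what a function defined only in `__main__` and then deleted cannot do.

```python
import pickle

import numpy as np
from sklearn.preprocessing import FunctionTransformer

X = np.array([[0.0, 1.0], [2.0, 3.0]])

# FunctionTransformer just calls the wrapped function on X
log_tf = FunctionTransformer(np.log1p)
out = log_tf.fit(X).transform(X)

# np.log1p is importable by name, so unpickling can resolve it again
restored = pickle.loads(pickle.dumps(log_tf))
out_restored = restored.transform(X)
```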
I was able to hack a solution using the marshal module (in addition to pickle) and overriding the magic methods __getstate__ and __setstate__ that pickle uses.
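import marshal
from types import FunctionType
from sklearn.base import BaseEstimator, TransformerMixin

class MyFunctionTransformer(BaseEstimator, TransformerMixin):
    def __init__(self, func):
        # Store under the same name as the __init__ parameter so that
        # BaseEstimator.get_params() can find it
        self.func = func

    def __call__(self, X):
        return self.func(X)

    def __getstate__(self):
        # Copy the state so that dumping does not strip func from the live object
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        # marshal output is tied to the Python version that wrote it
        state["func_code"] = marshal.dumps(func.__code__)
        return state

    def __setstate__(self, state):
        state = dict(state)
        code = marshal.loads(state.pop("func_code"))
        name = state.pop("func_name")
        # Globals resolve in this module; closures and defaults are not restored
        state["func"] = FunctionType(code, globals(), name)
        self.__dict__.update(state)

    def fit(self, X, y=None):
        return self

    def transform(self, X):
        return self.func(X)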
Now, if we use MyFunctionTransformer instead of FunctionTransformer, the code works as expected:
import joblib  # sklearn.externals.joblib was removed in scikit-learn 0.23
from sklearn.pipeline import Pipeline

@MyFunctionTransformer
def my_transform(x):
    return x*2

pipe = Pipeline([("times_2", my_transform)])
joblib.dump(pipe, "pipe.joblib")
del pipe
del my_transform
pipe = joblib.load("pipe.joblib")
This works by removing the wrapped function from the pickled state and instead marshaling its code object and its name; on load, __setstate__ rebuilds the function from those two pieces.
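To see the mechanism in isolation, here is a stripped-down, stdlib-only sketch of the same idea. The class name FunctionHolder is hypothetical and drops the scikit-learn base classes. Like the class above, it assumes a plain function with no closure or default arguments, globals that resolve in the holder's own module, and the same Python version on both ends, since marshal's format is version-specific.

```python
import marshal
import pickle
from types import FunctionType


class FunctionHolder:
    """Hypothetical minimal wrapper that pickles its function by value."""

    def __init__(self, func):
        self.func = func

    def __call__(self, x):
        return self.func(x)

    def __getstate__(self):
        # Work on a copy so pickling does not break the live object
        state = self.__dict__.copy()
        func = state.pop("func")
        state["func_name"] = func.__name__
        state["func_code"] = marshal.dumps(func.__code__)
        return state

    def __setstate__(self, state):
        state = dict(state)
        code = marshal.loads(state.pop("func_code"))
        name = state.pop("func_name")
        # Rebuild the function against this module's globals
        state["func"] = FunctionType(code, globals(), name)
        self.__dict__.update(state)


def double(x):
    return x * 2


holder = FunctionHolder(double)
payload = pickle.dumps(holder)
del double  # the global name is gone, but the payload carries the bytecode

restored = pickle.loads(payload)
result = restored(21)
```

Only the class itself is pickled by reference here, so FunctionHolder (or MyFunctionTransformer) must still be importable when loading.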
dill also looks like a good alternative to marshaling.