
Pickling monkey-patched Keras model for use in PySpark

My overall goal is to send a Keras model to each Spark worker so that I can use the model within a UDF applied to a column of a DataFrame. To do this, the Keras model needs to be picklable.

It seems that many people have had success pickling Keras models by monkey-patching the Model class, as described in the blog post cited in the code below.


However, I have not seen any example of how to do this in tandem with Spark. My first attempt just ran the make_keras_picklable() function in the driver. That allowed me to pickle and unpickle the model in the driver, but I could not use the pickled model in UDFs.

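import numpy as np
from keras.models import Sequential
from pyspark.sql.functions import udf
from pyspark.sql.types import ArrayType, FloatType

def make_keras_picklable():
    "Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"
    ...  # body omitted; patches __getstate__/__setstate__ onto the Keras Model class

make_keras_picklable()  # runs in the driver process only

model = Sequential() # etc etc

def score(case):
    # predict() expects a batch, so wrap the single row and return plain floats
    return model.predict(np.array([case]))[0].tolist()

scoreUDF = udf(score, ArrayType(FloatType()))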

The error I get suggests that unpickling the model inside the UDF does not use the monkey-patched Model class:

AttributeError: 'Sequential' object has no attribute '_built'
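One plausible reading of this error, sketched in plain Python so it runs without Keras or Spark: the driver's patched __getstate__ replaces the model's attributes with a single serialized blob, but a worker process that never ran the patch has no matching __setstate__, so pickle just copies that blob into the new object's __dict__ and the real attributes (such as _built) never come back. FakeModel, _getstate and _setstate are hypothetical stand-ins, and deleting __setstate__ before unpickling only simulates a fresh worker process.

```python
import pickle


class FakeModel:
    """Hypothetical stand-in for keras.models.Model (no Keras needed)."""

    def __init__(self):
        self._built = True

    def predict(self, x):
        # Real Keras models read internal attributes like this one.
        if not self._built:
            raise RuntimeError("model not built")
        return x


def _getstate(self):
    # Patched hook: replace the whole instance state with serialized bytes.
    return {"model_str": b"...serialized weights..."}


def _setstate(self, state):
    # Patched hook: rebuild the real attributes from the serialized bytes.
    self.__dict__.update({"_built": True})


# "Driver": patch the class, then pickle an instance.
FakeModel.__getstate__ = _getstate
FakeModel.__setstate__ = _setstate
payload = pickle.dumps(FakeModel())

# "Worker": a fresh process never ran the patch. Simulate that by removing
# the __setstate__ hook, so pickle falls back to updating __dict__.
del FakeModel.__setstate__
restored = pickle.loads(payload)
print(restored.__dict__)  # {'model_str': b'...serialized weights...'}

try:
    restored.predict(1)
    message = None
except AttributeError as err:
    message = str(err)
print(message)  # 'FakeModel' object has no attribute '_built'
```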

It looks like another user ran into similar errors in this SO post, and the answer was to "run make_keras_picklable() on each worker as well." No example of how to do this was given.

My question is: What is the appropriate way to call make_keras_picklable() on all workers?

I also tried using broadcast(), but got the same error as above.
Erp12 asked Apr 24 '18 16:04


2 Answers

Khaled Zaouk, over on the Spark user mailing list, helped me out by suggesting that make_keras_picklable() be replaced with a wrapper class. This worked great!

import tempfile

import tensorflow as tf

class KerasModelWrapper:
    """Source: https://zachmoshe.com/2017/04/03/pickling-keras-models.html"""

    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Save the model to a temporary HDF5 file and keep only its raw bytes.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            tf.keras.models.save_model(self.model, fd.name, overwrite=True)
            model_str = fd.read()
        return {"model_str": model_str}

    def __setstate__(self, state):
        # Write the bytes back to a temporary file *before* loading it;
        # otherwise load_model() would be pointed at an empty file.
        with tempfile.NamedTemporaryFile(suffix=".hdf5", delete=True) as fd:
            fd.write(state["model_str"])
            fd.flush()
            self.model = tf.keras.models.load_model(fd.name)

Of course, this could probably be made more elegant by implementing it as a subclass of Keras's Model class, or perhaps as a PySpark ML Transformer/Estimator.
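The wrapper's save-to-bytes / restore-from-bytes round trip can be checked without TensorFlow. In this sketch, save_to_path and load_from_path are hypothetical stand-ins for tf.keras.models.save_model and tf.keras.models.load_model, JSON plays the role of HDF5, and a plain dict plays the role of the model.

```python
import json
import pickle
import tempfile


def save_to_path(obj, path):
    """Hypothetical stand-in for tf.keras.models.save_model."""
    with open(path, "w") as f:
        json.dump(obj, f)


def load_from_path(path):
    """Hypothetical stand-in for tf.keras.models.load_model."""
    with open(path) as f:
        return json.load(f)


class ModelWrapper:
    def __init__(self, model):
        self.model = model

    def __getstate__(self):
        # Save to a temp file, then keep only the raw bytes in the pickled state.
        with tempfile.NamedTemporaryFile(suffix=".json") as fd:
            save_to_path(self.model, fd.name)
            return {"model_str": fd.read()}

    def __setstate__(self, state):
        # Write the bytes back to a temp file before loading from it.
        with tempfile.NamedTemporaryFile(suffix=".json") as fd:
            fd.write(state["model_str"])
            fd.flush()
            self.model = load_from_path(fd.name)


wrapped = ModelWrapper({"weights": [0.5, -1.0]})
restored = pickle.loads(pickle.dumps(wrapped))
print(restored.model)  # {'weights': [0.5, -1.0]}
```

A likely reason the wrapper succeeds where the monkey patch did not: the pickling hooks live on the wrapper class itself, and PySpark's cloudpickle serializes classes defined in the driver's `__main__` by value, so `__setstate__` reaches each worker together with the object. If the wrapper lives in a separate module instead, that module presumably has to be importable on the workers. Like the original, this relies on reopening a NamedTemporaryFile by name while it is still open, which works on Linux and macOS but not on Windows.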

Erp12 answered Sep 21 '22 20:09


Following the same idea as Erp12, you can use the class below to wrap a Keras model. It exposes the wrapped model's attributes dynamically, in the spirit of the decorator pattern, and subclasses the Keras Model class, as Erp12 suggested.

import tempfile
import tensorflow as tf

class PicklableKerasModel(tf.keras.models.Model):

    def __init__(self, model):
        self._model = model

    def __getstate__(self):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            tf.keras.models.save_model(self._model, fd.name, overwrite=True)
            model_str = fd.read()
        d = {'model_str': model_str}
        return d

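    def __setstate__(self, state):
        # Restore the saved HDF5 bytes to a temp file, then load the model from it.
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = tf.keras.models.load_model(fd.name)
        self._model = model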

    def __getattr__(self, name):
        return getattr(self.__dict__['_model'], name)

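    def __setattr__(self, name, value):
        if name == '_model':
            # The wrapped model itself is stored on the wrapper...
            self.__dict__['_model'] = value
        else:
            # ...and every other attribute is forwarded to the wrapped model.
            setattr(self.__dict__['_model'], name, value)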

    def __delattr__(self, name):
        delattr(self.__dict__['_model'], name)

Then you can wrap your Keras model like this:

from tensorflow.keras.models import Sequential

model = Sequential() # etc etc

picklable_model = PicklableKerasModel(model)
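The attribute delegation in PicklableKerasModel can be exercised without Keras. This sketch copies its `__getattr__`/`__setattr__` scheme (with an `else` branch so that only `_model` is stored on the wrapper) into a hypothetical DelegatingWrapper class, wrapped around a SimpleNamespace standing in for a model; it leaves out the Keras base class and the pickling hooks.

```python
from types import SimpleNamespace


class DelegatingWrapper:
    """Same delegation scheme as PicklableKerasModel, minus the Keras base class."""

    def __init__(self, model):
        self._model = model

    def __getattr__(self, name):
        # Only called when normal lookup on the wrapper fails.
        return getattr(self.__dict__["_model"], name)

    def __setattr__(self, name, value):
        if name == "_model":
            self.__dict__["_model"] = value
        else:
            setattr(self.__dict__["_model"], name, value)


inner = SimpleNamespace(trainable=True, predict=lambda x: [v * 2 for v in x])
wrapper = DelegatingWrapper(inner)

print(wrapper.predict([1, 2]))  # [2, 4]
wrapper.trainable = False       # forwarded to the wrapped object
print(inner.trainable)          # False
```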
Gerardo González Seco answered Sep 23 '22 20:09
