
How to pickle Keras custom layer?

I wrote a custom layer class that extends Keras's Layer class, and I want to pickle the training history for further analysis. But when I reload the pickled object from the file, Python raises an error:

Unknown Layer: Attention.

So, how can I fix it?

I have tried get_config, __getstate__, and __setstate__, but none of them worked. I only want to pickle the Keras history, not the model, so please don't point me to the model-saving methods with the custom_objects parameter.

Tandy asked Jan 07 '19 at 08:01

1 Answer

This problem occurs because dumping the history also dumps the model it references, and that model is not dumped completely. So when the history is loaded, the custom layer class cannot be found.

I noticed that the keras.callbacks.History object has a model attribute, and the incomplete dump of that model is what causes this problem.
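The effect can be reproduced without Keras at all. In this minimal sketch, StandInModel, StandInHistory and _rebuild_model are hypothetical stand-ins (not Keras classes): the stand-in model pickles without complaint, but, like a model whose custom layer is not registered, it cannot be rebuilt on load. The error message is made up to mirror the one in the question.

```python
import pickle


def _rebuild_model(config):
    """Hypothetical loader: mimics Keras failing to resolve a custom layer class."""
    raise ValueError("Unknown layer: " + config["class_name"])


class StandInModel:
    """Hypothetical stand-in for a Keras model that contains a custom layer."""

    def __init__(self, class_name):
        self.class_name = class_name

    def __reduce__(self):
        # Pickle records only "call _rebuild_model(config)"; nothing fails yet.
        return (_rebuild_model, ({"class_name": self.class_name},))


class StandInHistory:
    """Hypothetical stand-in for keras.callbacks.History."""

    def __init__(self, history, model):
        self.history = history  # plain dict of per-epoch metrics
        self.model = model      # reference back to the trained model


hist = StandInHistory({"loss": [0.5]}, StandInModel("Attention"))
data = pickle.dumps(hist)  # dumping succeeds
try:
    pickle.loads(data)     # loading tries to rebuild the model and fails
except ValueError as err:
    print(err)             # Unknown layer: Attention
```

The metrics dict itself is harmless; only the model reference drags the unresolvable class into the pickle.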

And you said:

I just want to pickle the keras history, but not the model

So here is a workaround:

hist = model.fit(X, Y, ...)
hist.model = None

Just set the model attribute to None, and you can then dump and load your history object successfully.
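Setting hist.model = None permanently drops the reference, which matters if you still want to call hist.model afterwards. A minimal sketch of a variant that detaches the model only while pickling and then restores it; dump_without_model, load_history and SimpleHistory are hypothetical names used for illustration, not Keras API, and a lambda plays the role of an unpicklable model.

```python
import os
import pickle
import tempfile


def dump_without_model(hist, path):
    """Pickle a History-like object to `path` with its `model` attribute detached.

    The original model reference is put back afterwards, even if pickling fails.
    """
    saved_model = hist.model
    hist.model = None
    try:
        with open(path, "wb") as f:
            pickle.dump(hist, f)
    finally:
        hist.model = saved_model


def load_history(path):
    """Load an object written by dump_without_model."""
    with open(path, "rb") as f:
        return pickle.load(f)


class SimpleHistory:
    """Hypothetical stand-in for keras.callbacks.History (demo only)."""

    def __init__(self, history, model):
        self.history = history  # dict: metric name -> list of per-epoch values
        self.model = model      # reference to the trained model


# pickle cannot serialize a lambda, so this only works because
# the reference is detached before dumping.
hist = SimpleHistory({"loss": [0.9, 0.7]}, model=lambda x: x)
path = os.path.join(tempfile.mkdtemp(), "hist.pkl")
dump_without_model(hist, path)

reloaded = load_history(path)
print(reloaded.history)      # {'loss': [0.9, 0.7]}
print(reloaded.model)        # None
print(callable(hist.model))  # True: the original reference was restored
```

If you only need the metrics, pickling hist.history alone also works, since it is a plain dict of lists.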

Here is an MCVE:

from keras.models import Sequential
from keras.layers import Conv2D, Dense, Flatten, Layer
import keras.backend as K
import numpy as np
import pickle

# MyLayer from https://keras.io/layers/writing-your-own-keras-layers/
class MyLayer(Layer):

    def __init__(self, output_dim, **kwargs):
        self.output_dim = output_dim
        super(MyLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        # Create a trainable weight variable for this layer.
        self.kernel = self.add_weight(name='kernel',
                                      shape=(input_shape[1], self.output_dim),
                                      initializer='uniform',
                                      trainable=True)
        super(MyLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        return K.dot(x, self.kernel)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)

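model = Sequential()
model.add(Conv2D(filters=32, kernel_size=(3,3), input_shape=(28,28,3), activation='sigmoid'))
model.add(Flatten())    # flatten the feature maps to 2D for the layers below
model.add(MyLayer(10))  # use the custom layer so the model actually contains one
model.add(Dense(3, activation='softmax'))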

model.compile(loss='sparse_categorical_crossentropy', metrics=['accuracy'], optimizer='adam')


X = np.random.randn(64, 28, 28, 3)
Y = np.random.randint(0, high=2, size=(64,1))

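hist = model.fit(X, Y, batch_size=8)
print(hist.history)

# Drop the reference to the model so only the metrics get pickled
hist.model = None

with open('hist.pkl', 'wb') as f:
    pickle.dump(hist, f)

with open('hist.pkl', 'rb') as f:
    hist_reloaded = pickle.load(f)

print(hist_reloaded.history)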


The output:

{'acc': [0.484375], 'loss': [6.140302091836929]}

{'acc': [0.484375], 'loss': [6.140302091836929]}

P.S. If you want to save a Keras model that contains a custom layer, this should be helpful.

keineahnung2345 answered Oct 07 '22 at 03:10
