Keras: Serializing a Masking Layer for save/load

Tags:

python

keras

So I have a custom layer in Keras that uses a Masking layer inside it.

To get it to work with save/load, I need to serialize the mask correctly, but this standard code doesn't work:

def get_config(self):
    config =  {'mask': self.mask}
    base_config = super(Mixing, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

where self.mask is a reference to the Masking layer.

I'm not sure how to serialize Masking (or Keras Layers in general). Can anyone help?
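As a sketch of the general idea (assuming a Keras install where keras.layers.serialize and keras.layers.deserialize are available): the standard-library json module cannot encode a layer object directly, but serialize turns a layer into a plain dict holding its class name and its own get_config() output, and deserialize builds a new, equivalent layer from that dict.

```python
import json

import keras

# A Masking layer, standing in for the one stored as self.mask in the question.
mask = keras.layers.Masking(mask_value=0.0)

# json cannot encode the layer object itself.
try:
    json.dumps({"mask": mask})
except TypeError as err:
    print("cannot encode the layer object:", err)

# serialize() returns a plain dict (class name plus the layer's config),
# which json can encode and decode.
serialized = keras.layers.serialize(mask)
text = json.dumps(serialized)

# deserialize() builds a new, equivalent Masking layer from that dict.
restored = keras.layers.deserialize(json.loads(text))
print(type(restored).__name__, restored.mask_value)
```

The same pair of calls is what a custom layer needs inside its own get_config and from_config when it holds another layer.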

anon asked Oct 20 '25 10:10


1 Answer

You can implement the same serialization methods as the built-in Wrapper class, which also holds a reference to another layer (self.layer):

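def get_config(self):
    # Store the inner layer as plain data: its class name plus its own config.
    config = {'layer': {'class_name': self.layer.__class__.__name__,
                        'config': self.layer.get_config()}}
    # In Keras's source this class is Wrapper; in your own layer, use your class (e.g. Mixing).
    base_config = super(Wrapper, self).get_config()
    return dict(list(base_config.items()) + list(config.items()))

@classmethod
def from_config(cls, config, custom_objects=None):
    # Keras's source uses a relative import here ("from . import deserialize");
    # outside the Keras package, import the same function from keras.layers.
    from keras.layers import deserialize as deserialize_layer
    # Rebuild the inner layer from its stored dict, then pass it to the constructor.
    layer = deserialize_layer(config.pop('layer'),
                              custom_objects=custom_objects)
    return cls(layer, **config)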

During serialization, in get_config, the inner layer's class name and config are saved in config['layer'].

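In from_config, the inner layer is rebuilt from config['layer'] with deserialize_layer and then passed to the constructor, so the restored layer receives a real layer object rather than a dict.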
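Applied to the layer from the question, a minimal sketch could look like the following. The question does not show Mixing's constructor or call, so this assumes a hypothetical Mixing(mask_layer, **kwargs) that stores the Masking layer as self.mask_layer (a hypothetical name) and whose call only applies that mask. It also uses keras.layers.serialize in place of the hand-built {'class_name': ..., 'config': ...} dict; it records the same class name and config (newer Keras versions add a few extra keys).

```python
import keras


class Mixing(keras.layers.Layer):
    """Hypothetical stand-in for the question's layer: it holds a Masking layer."""

    def __init__(self, mask_layer, **kwargs):
        super().__init__(**kwargs)
        self.mask_layer = mask_layer  # a keras.layers.Masking instance

    def call(self, inputs):
        # Only applies the inner Masking layer; the real mixing logic is not shown.
        return self.mask_layer(inputs)

    def get_config(self):
        config = super().get_config()
        # Store a plain dict describing the inner layer, not the layer object.
        config["mask_layer"] = keras.layers.serialize(self.mask_layer)
        return config

    @classmethod
    def from_config(cls, config, custom_objects=None):
        config = dict(config)  # copy so the caller's dict is left untouched
        # Rebuild the Masking layer from its dict, then pass it to the constructor.
        mask_layer = keras.layers.deserialize(
            config.pop("mask_layer"), custom_objects=custom_objects
        )
        return cls(mask_layer, **config)


# Round trip through the config, the same path used when a saved model is loaded.
layer = Mixing(keras.layers.Masking(mask_value=0.0), name="mixing")
clone = Mixing.from_config(layer.get_config())
print(type(clone.mask_layer).__name__, clone.mask_layer.mask_value)
```

When loading a whole saved model, the custom class still has to be made known to the loader, for example with keras.models.load_model(path, custom_objects={'Mixing': Mixing}).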

Yu-Yang answered Oct 22 '25 00:10