I am trying to load an RNN model architecture trained in Keras using keras.models.model_from_json, and I am getting the error shown below.
# Read the serialized architecture, then rebuild the model from it.
with open('model_architecture.json', 'r') as arch_file:
    architecture_json = arch_file.read()
# The custom layer class must be supplied so it can be resolved by name.
model = model_from_json(architecture_json, custom_objects={'AttLayer': AttLayer})
# Load weights into the new model
model.load_weights('model_weights.h5')
Here is the custom layer I am using:
class AttLayer(Layer):
    """Attention-weighted sum over axis 1 of a rank-3 input.

    Each position along axis 1 gets a scalar weight computed from
    ``tanh(x . W + b) . u``; the weights are normalised along axis 1 and the
    input is summed along that axis using them.
    """

    def __init__(self, attention_dim):
        """Store the attention size and set up the weight initializer.

        NOTE: this signature accepts no ``**kwargs``. ``Layer.from_config``
        calls ``cls(**config)`` with generic keys such as ``trainable`` and
        ``name``, which is what raises the TypeError shown in the traceback.
        """
        self.attention_dim = attention_dim
        # NOTE(review): the base Layer.__init__ may reset supports_masking
        # because it runs after this assignment -- confirm against Keras.
        self.supports_masking = True
        self.init = initializers.get('normal')
        super(AttLayer, self).__init__()

    def build(self, input_shape):
        """Create W, b and u once the input shape is known (rank 3 only)."""
        assert len(input_shape) == 3
        feature_dim = input_shape[-1]
        # W: (feature_dim, attention_dim), b: (attention_dim,),
        # u: (attention_dim, 1)
        self.W = K.variable(self.init((feature_dim, self.attention_dim)))
        self.b = K.variable(self.init((self.attention_dim, )))
        self.u = K.variable(self.init((self.attention_dim, 1)))
        self.trainable_weights = [self.W, self.b, self.u]
        super(AttLayer, self).build(input_shape)

    def compute_mask(self, inputs, mask=None):
        """Return None: no mask is propagated past this layer."""
        return None

    def call(self, x, mask=None):
        """Return the attention-weighted sum of ``x`` along axis 1."""
        # hidden = tanh(x . W + b)
        hidden = K.tanh(K.bias_add(K.dot(x, self.W), self.b))
        # Project onto u, drop the trailing size-1 axis, then exponentiate.
        scores = K.exp(K.squeeze(K.dot(hidden, self.u), -1))
        if mask is not None:
            # Cast the mask to floatX to avoid float64 upcasting in theano
            scores = scores * K.cast(mask, K.floatx())
        # Normalise along axis 1; epsilon guards against division by zero.
        normalizer = K.cast(
            K.sum(scores, axis=1, keepdims=True) + K.epsilon(), K.floatx())
        weights = K.expand_dims(scores / normalizer)
        return K.sum(x * weights, axis=1)

    def compute_output_shape(self, input_shape):
        """Axis 1 is summed away, leaving (first dim, last dim) of the input."""
        return (input_shape[0], input_shape[-1])

    def get_config(self):
        """Return the base layer config extended with ``attention_dim``."""
        merged = dict(super(AttLayer, self).get_config())
        merged.update({'attention_dim': self.attention_dim})
        return merged
error:
File "scripts/Classifier.py", line 254, in test
model = model_from_json(f.read(), custom_objects={'AttLayer':AttLayer})
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/models.py", line 345, in model_from_json
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
printable_module_name='layer')
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
list(custom_objects.items())))
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
process_layer(layer_data)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
custom_objects=custom_objects)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
printable_module_name='layer')
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
list(custom_objects.items())))
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/wrappers.py", line 100, in from_config
custom_objects=custom_objects)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
printable_module_name='layer')
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 139, in deserialize_keras_object
list(custom_objects.items())))
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2489, in from_config
process_layer(layer_data)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 2475, in process_layer
custom_objects=custom_objects)
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/layers/__init__.py", line 54, in deserialize
printable_module_name='layer')
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/utils/generic_utils.py", line 141, in deserialize_keras_object
return cls.from_config(config['config'])
File "/home/biswadip/.local/lib/python2.7/site-packages/keras/engine/topology.py", line 1254, in from_config
return cls(**config)
TypeError: __init__() got an unexpected keyword argument 'trainable'
Versions:
Keras==2.0.8
tensorflow==1.4.1
I tried training and loading with different versions, but with no luck. Finally, I removed the 'trainable' and 'name' key-value pairs from my custom layer's entry in the model architecture file (model_architecture.json), and the model now seems to load without any error. But this looks like a workaround rather than a real fix, and I would have to do it every time I train the model.
I think you missed a small detail in your layer definition. Your layer's __init__ method should accept keyword arguments (**kwargs) and pass them on to the parent class's __init__, like this:
class AttLayer(Layer):
    def __init__(self, attention_dim, **kwargs):
        """Create the layer.

        ``attention_dim`` is the size of the attention projection. Any other
        keyword arguments (e.g. ``name``, ``trainable``) are generic Layer
        options and are forwarded to the base class, so that
        ``from_config`` / ``model_from_json`` can rebuild the layer.
        """
        # Initialise the base Layer first: its __init__ sets default
        # attributes (including supports_masking), so values assigned before
        # this call could be overwritten.
        super(AttLayer, self).__init__(**kwargs)
        self.init = initializers.get('normal')
        self.supports_masking = True
        self.attention_dim = attention_dim
This way, any generic layer parameter (in your case, the trainable flag) is correctly passed on to the parent class.
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With