I made a custom layer in keras for reshaping the outputs of a CNN before feeding to ConvLSTM2D layer
class TemporalReshape(Layer):
def __init__(self,batch_size,num_patches):
super(TemporalReshape,self).__init__()
self.batch_size = batch_size
self.num_patches = num_patches
def call(self,inputs):
nshape = (self.batch_size,self.num_patches)+inputs.shape[1:]
return tf.reshape(inputs, nshape)
def get_config(self):
config = super().get_config().copy()
config.update({'batch_size':self.batch_size,'num_patches':self.num_patches})
return config
When I try to load the best model using
model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
I get the error
TypeError Traceback (most recent call last)
<ipython-input-83-40b46da33e91> in <module>()
----> 1 model = tf.keras.models.load_model('/content/saved_models/model_best.h5',custom_objects={'TemporalReshape':TemporalReshape})
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
180 if (h5py is not None and (
181 isinstance(filepath, h5py.File) or h5py.is_hdf5(filepath))):
--> 182 return hdf5_format.load_model_from_hdf5(filepath, custom_objects, compile)
183
184 filepath = path_to_string(filepath)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/hdf5_format.py in load_model_from_hdf5(filepath, custom_objects, compile)
176 model_config = json.loads(model_config.decode('utf-8'))
177 model = model_config_lib.model_from_config(model_config,
--> 178 custom_objects=custom_objects)
179
180 # set weights
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/saving/model_config.py in model_from_config(config, custom_objects)
53 '`Sequential.from_config(config)`?')
54 from tensorflow.python.keras.layers import deserialize # pylint: disable=g-import-not-at-top
---> 55 return deserialize(config, custom_objects=custom_objects)
56
57
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
173 module_objects=LOCAL.ALL_OBJECTS,
174 custom_objects=custom_objects,
--> 175 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
356 custom_objects=dict(
357 list(_GLOBAL_CUSTOM_OBJECTS.items()) +
--> 358 list(custom_objects.items())))
359 with CustomObjectScope(custom_objects):
360 return cls.from_config(cls_config)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in from_config(cls, config, custom_objects)
615 """
616 input_tensors, output_tensors, created_layers = reconstruct_from_config(
--> 617 config, custom_objects)
618 model = cls(inputs=input_tensors, outputs=output_tensors,
619 name=config.get('name'))
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in reconstruct_from_config(config, custom_objects, created_layers)
1202 # First, we create all layers and enqueue nodes to be processed
1203 for layer_data in config['layers']:
-> 1204 process_layer(layer_data)
1205 # Then we process nodes in order of layer depth.
1206 # Nodes that cannot yet be processed (if the inbound node
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/functional.py in process_layer(layer_data)
1184 from tensorflow.python.keras.layers import deserialize as deserialize_layer # pylint: disable=g-import-not-at-top
1185
-> 1186 layer = deserialize_layer(layer_data, custom_objects=custom_objects)
1187 created_layers[layer_name] = layer
1188
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/serialization.py in deserialize(config, custom_objects)
173 module_objects=LOCAL.ALL_OBJECTS,
174 custom_objects=custom_objects,
--> 175 printable_module_name='layer')
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/utils/generic_utils.py in deserialize_keras_object(identifier, module_objects, custom_objects, printable_module_name)
358 list(custom_objects.items())))
359 with CustomObjectScope(custom_objects):
--> 360 return cls.from_config(cls_config)
361 else:
362 # Then `cls` may be a function returning a class.
/usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/engine/base_layer.py in from_config(cls, config)
695 A layer instance.
696 """
--> 697 return cls(**config)
698
699 def compute_output_shape(self, input_shape):
TypeError: __init__() got an unexpected keyword argument 'name'
When building the model, I used the custom layer like the following :
x = TemporalReshape(batch_size = 8, num_patches = 16)(x)
What is causing the error and how to load the model without this error?
Based on the error message only, I would suggest putting **kwargs
in __init__
. This object will then accept any other keyword argument that you haven't included.
def __init__(self, batch_size, num_patches, **kwargs):
super(TemporalReshape, self).__init__()
self.batch_size = batch_size
self.num_patches = num_patches
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With