import tensorflow as tf

class MyModel(tf.keras.Model):
    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    @tf.function
    def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
        x = self.dense1(enc_input)
        return self.dense2(x)

x = tf.random.normal((10,20))
model = MyModel()
y = model(x, x, False, None, None, None)
tf.keras.models.save_model(model, '/saved')
When I try to save the model, it throws the following error, even though I am passing all the arguments when I call it:
tf__call() missing 4 required positional arguments: 'training', 'mask1', 'mask2', and 'mask3'
How can I save the entire model, and not just its weights?
I think the following change would work: give the extra arguments of call default values.
#def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
def call(self, enc_input, dec_input, training=False, mask1=None, mask2=None, mask3=None):
After digging, I think a sanity check is run on the function arguments when the model is saved: call is invoked again with only the inputs, so any parameter after them that has no default value is reported as missing (I am not really sure about this). If you would rather not give the arguments default values, you can also unpack them from a list so that each one goes in its corresponding place. Note that this is equivalent to the plain positional call in the question, so I am not sure it changes anything at save time:
y = model(*[x,x,False,None,None,None])
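To make the reasoning above concrete, here is a minimal pure-Python sketch with no TensorFlow involved. It is only an analogy: call_required, call_with_defaults and required_after_inputs are hypothetical names, and the assumption that the saving code re-invokes call with just the two inputs is a guess based on the error message, not something taken from the Keras source.

```python
import inspect

# Hypothetical stand-ins for the two versions of MyModel.call (self omitted).
def call_required(enc_input, dec_input, training, mask1, mask2, mask3):
    return enc_input

def call_with_defaults(enc_input, dec_input, training=False,
                       mask1=None, mask2=None, mask3=None):
    return enc_input

def required_after_inputs(fn, n_inputs=2):
    """Return names of parameters past the first n_inputs that have no default."""
    params = list(inspect.signature(fn).parameters.values())
    return [p.name for p in params[n_inputs:]
            if p.default is inspect.Parameter.empty]

# Parameters that a caller passing only the two inputs would leave unfilled.
missing_before = required_after_inputs(call_required)
missing_after = required_after_inputs(call_with_defaults)

# Assumption: the saving code calls the function with the inputs only.
try:
    call_required("enc", "dec")
    error_message = None
except TypeError as exc:
    error_message = str(exc)

# With defaults in place, the same inputs-only call succeeds.
result = call_with_defaults("enc", "dec")

# Star-unpacking a list is the same as passing the values positionally.
args = ["enc", "dec", False, None, None, None]
positional = call_required("enc", "dec", False, None, None, None)
unpacked = call_required(*args)

print(missing_before)
print(missing_after)
print(error_message)
```

The inputs-only call fails for the first signature and works for the second, which matches the error in the question. The last part shows why unpacking alone is unlikely to help: it produces exactly the same call as writing the arguments out one by one.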