
How to save a Keras subclassed model with positional parameters in the call() method?

import tensorflow as tf

class MyModel(tf.keras.Model):

    def __init__(self):
        super(MyModel, self).__init__()
        self.dense1 = tf.keras.layers.Dense(4, activation=tf.nn.relu)
        self.dense2 = tf.keras.layers.Dense(5, activation=tf.nn.softmax)

    @tf.function
    def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
        x = self.dense1(enc_input)
        return self.dense2(x)

x = tf.random.normal((10,20))

model = MyModel()

y = model(x, x, False, None, None, None)

tf.keras.models.save_model(model, '/saved')

When I try to save the model, it throws an error even though I'm passing all the arguments:

tf__call() missing 4 required positional arguments: 'training', 'mask1', 'mask2', and 'mask3'

How can I save the entire model, and not just the weights?

asked Oct 19 '20 at 10:10 by Uchiha Madara



1 Answer

I think the following change would work (give the non-tensor parameters default values):

    # Before:
    # def call(self, enc_input, dec_input, training, mask1, mask2, mask3):
    # After:
    def call(self, enc_input, dec_input, training=False, mask1=None, mask2=None, mask3=None):
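
To see why the defaults matter, here is a plain-Python sketch (no TensorFlow needed). It assumes, going by the error message in the question, that saving invokes call() again with only the two tensor inputs. The names retrace, call_required and call_with_defaults are hypothetical stand-ins for illustration, not Keras APIs.

```python
# Plain-Python stand-ins for the two call() signatures (no TensorFlow required).

def call_required(enc_input, dec_input, training, mask1, mask2, mask3):
    """Original signature: every parameter is required."""
    return enc_input + dec_input


def call_with_defaults(enc_input, dec_input, training=False,
                       mask1=None, mask2=None, mask3=None):
    """Suggested signature: the non-tensor parameters have defaults."""
    return enc_input + dec_input


def retrace(fn, *tensor_inputs):
    """Hypothetical caller that passes only the tensor inputs (an assumption)."""
    try:
        return fn(*tensor_inputs)
    except TypeError as err:
        # Report the failure as text instead of crashing the script.
        return "TypeError: " + str(err)


result_required = retrace(call_required, 1, 2)
result_defaults = retrace(call_with_defaults, 1, 2)
print(result_required)
print(result_defaults)
```

The first call fails with the same kind of "missing 4 required positional arguments" message as in the question, while the second succeeds because the missing parameters fall back to their defaults.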

After digging, I think there is a sanity check on the function arguments: if the positional arguments are not specified, the parameters after x are treated as **kwargs arguments (I am not really sure about this). For what it's worth, if you don't want to set default values for the arguments, you can unpack them so that each argument goes to its corresponding position, as follows:

y = model(*[x,x,False,None,None,None])
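
For reference, unpacking a list with * is ordinary Python positional-argument passing, so it is worth checking how the arguments actually bind. A small sketch using inspect.signature on a hypothetical stand-in function (the string "x" stands in for the tensor; this is not the real MyModel.call):

```python
import inspect


def call(enc_input, dec_input, training, mask1, mask2, mask3):
    """Hypothetical stand-in with the same parameters as MyModel.call."""
    return None


sig = inspect.signature(call)

# Bind the arguments written out one by one, then via list unpacking.
explicit = sig.bind("x", "x", False, None, None, None).arguments
unpacked = sig.bind(*["x", "x", False, None, None, None]).arguments

same_binding = dict(explicit) == dict(unpacked)
print(same_binding)
```

If the two bindings compare equal, the unpacked form changes only how the call is written, not which parameter each value reaches, so whether it helps with saving would still depend on how call() is invoked at save time.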

answered Oct 08 '22 at 09:10 by Eliethesaiyan
