Logo Questions Linux Laravel Mysql Ubuntu Git Menu
 

Save and load keras subclassed models

I am trying to save and load the CNN encoder and RNN decoder from the TF tutorial on image captioning: https://www.tensorflow.org/tutorials/text/image_captioning. Since these are subclassed Keras models rather than Functional or Sequential ones, I could not use model.save and model.load directly.

Instead, I had to use model.save_weights and model.load_weights. The problem is that model.load_weights can be called only after model.build, and model.build requires an input_shape parameter, which has to be a single tuple, not a list of tuples. Our RNN decoder, however, has multiple inputs, and the Keras docs do not specify a way to call model.build with multiple inputs.

Is there any other way to load the model?

Eventually, I want to have a smaller Python script that can load the model weights and run inference. That script should not have to train.

Colab: https://colab.research.google.com/drive/12YtCH2X0pwIBBXPW0TXmeA520MyVv9AF

like image 626
Piyush Singh Avatar asked Oct 19 '25 02:10

Piyush Singh


1 Answers

Here is how I managed to work around the problem. It is not a very good solution, but it works! First, you save each weight matrix to its own .npy file:

import os

# np.save does not create missing parent directories, so make sure the
# output folder exists before writing the first file.
os.makedirs("encoder_layer_weights", exist_ok=True)

# Dump every weight of every encoder layer into its own .npy file.
# The file name encodes (layer index, layer name, weight index).
for i, layer in enumerate(encoder.layers):
    print("Layer %s" % i, layer.name)
    for j, w in enumerate(layer.weights):
        print(w.shape)
        np.save(
            "encoder_layer_weights/layer_%s_%s_weights_%s.npy" % (i, layer.name, j),
            w.numpy(),
        )


import os

# np.save does not create missing parent directories, so make sure the
# output folder exists before writing the first file.
os.makedirs("decoder_layer_weights", exist_ok=True)

# Dump every weight of every decoder layer into its own .npy file.
# The file name encodes (layer index, layer name, weight index).
for i, layer in enumerate(decoder.layers):
    print("Layer %s" % i, layer.name)
    for j, w in enumerate(layer.weights):
        print(w.shape)
        np.save(
            "decoder_layer_weights/layer_%s_%s_weights_%s.npy" % (i, layer.name, j),
            w.numpy(),
        )

Then you re-create the subclassed models, but this time you pass a constant initializer for each weight in each layer. This has to be done carefully, because if there is a shape mismatch your model won't build.

class CNN_Encoder(tf.keras.Model):
    """Encoder that applies one Dense layer followed by ReLU.

    The image features are expected to be extracted beforehand; this model
    only passes them through a fully connected layer. That layer is created
    with constant initializers read from the ``encoder_layer_weights``
    ``.npy`` files saved for layer index 0 (name ``dense``): weight file 0 is
    used as the kernel and weight file 1 as the bias.
    """

    def __init__(self, embedding_dim):
        super(CNN_Encoder, self).__init__()
        constant = tf.keras.initializers.Constant
        path_template = "encoder_layer_weights/layer_%s_%s_weights_%s.npy"
        kernel = np.load(path_template % (0, "dense", 0))
        bias = np.load(path_template % (0, "dense", 1))

        # Per the original tutorial: shape after fc is
        # (batch_size, 64, embedding_dim) -- not verifiable from this snippet.
        self.fc = tf.keras.layers.Dense(
            embedding_dim,
            kernel_initializer=constant(kernel),
            bias_initializer=constant(bias),
        )

    def call(self, x):
        """Return ``relu(fc(x))``."""
        return tf.nn.relu(self.fc(x))


class BahdanauAttention(tf.keras.Model):
    """Additive attention built from three Dense layers (W1, W2, V).

    The six weight files saved for decoder layer index 4 (name
    ``bahdanau_attention``) are read in order and used as
    (kernel, bias) for ``W1``, then ``W2``, then ``V``.
    """

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        constant = tf.keras.initializers.Constant
        path_template = "decoder_layer_weights/layer_%s_%s_weights_%s.npy"
        saved = [
            np.load(path_template % (4, "bahdanau_attention", idx))
            for idx in range(6)
        ]
        w1_kernel, w1_bias, w2_kernel, w2_bias, v_kernel, v_bias = saved

        def constant_dense(width, kernel, bias):
            # Dense layer whose kernel and bias start from the saved arrays.
            return tf.keras.layers.Dense(
                width,
                kernel_initializer=constant(kernel),
                bias_initializer=constant(bias),
            )

        self.W1 = constant_dense(units, w1_kernel, w1_bias)
        self.W2 = constant_dense(units, w2_kernel, w2_bias)
        self.V = constant_dense(1, v_kernel, v_bias)

    def call(self, features, hidden):
        """Return ``(context_vector, attention_weights)``.

        Shapes below follow the original tutorial comments and are
        assumptions about the caller, not something this code enforces:
        ``features`` is (batch_size, 64, embedding_dim) and ``hidden`` is
        (batch_size, hidden_size).
        """
        # Insert an axis at position 1 so `hidden` broadcasts against
        # `features`: (batch_size, 1, hidden_size).
        query = tf.expand_dims(hidden, 1)

        # score: (batch_size, 64, hidden_size)
        score = tf.nn.tanh(self.W1(features) + self.W2(query))

        # V has a single output unit, so the last axis becomes 1; the softmax
        # normalises over axis 1: (batch_size, 64, 1).
        attention_weights = tf.nn.softmax(self.V(score), axis=1)

        # Weighted sum of the features over axis 1: (batch_size, hidden_size).
        context_vector = tf.reduce_sum(attention_weights * features, axis=1)

        return context_vector, attention_weights


class RNN_Decoder(tf.keras.Model):
    """Decoder: attention -> embedding -> GRU -> two Dense layers.

    Every layer is created with constant initializers read from the
    ``decoder_layer_weights`` ``.npy`` files, addressed by the layer index
    and layer name they were saved under (0 ``embedding``, 1 ``gru``,
    2 ``dense_1``, 3 ``dense_2``). The attention sub-model loads its own
    weights in its constructor.
    """

    def __init__(self, embedding_dim, units, vocab_size):
        super(RNN_Decoder, self).__init__()
        self.units = units

        constant = tf.keras.initializers.Constant
        path_template = "decoder_layer_weights/layer_%s_%s_weights_%s.npy"

        def load(layer_idx, layer_name, weight_idx):
            # Read one saved weight array for the given layer / weight slot.
            return np.load(path_template % (layer_idx, layer_name, weight_idx))

        embedding_matrix = load(0, "embedding", 0)
        gru_kernel, gru_recurrent, gru_bias = [load(1, "gru", k) for k in range(3)]
        fc1_kernel, fc1_bias = [load(2, "dense_1", k) for k in range(2)]
        fc2_kernel, fc2_bias = [load(3, "dense_2", k) for k in range(2)]

        # Layers are created in the same order as before: embedding, GRU,
        # fc1, fc2, attention.
        self.embedding = tf.keras.layers.Embedding(
            vocab_size,
            embedding_dim,
            embeddings_initializer=constant(embedding_matrix),
        )
        self.gru = tf.keras.layers.GRU(
            self.units,
            return_sequences=True,
            return_state=True,
            kernel_initializer=constant(gru_kernel),
            recurrent_initializer=constant(gru_recurrent),
            bias_initializer=constant(gru_bias),
        )
        self.fc1 = tf.keras.layers.Dense(
            self.units,
            kernel_initializer=constant(fc1_kernel),
            bias_initializer=constant(fc1_bias),
        )
        self.fc2 = tf.keras.layers.Dense(
            vocab_size,
            kernel_initializer=constant(fc2_kernel),
            bias_initializer=constant(fc2_bias),
        )

        self.attention = BahdanauAttention(self.units)

    def call(self, x, features, hidden):
        """Run one decoding step; return ``(x, state, attention_weights)``.

        Shape comments follow the original tutorial and are assumptions
        about the inputs, not something this code checks.
        """
        # Attention is a separate model. NOTE: `hidden` is consumed only
        # here; the GRU below is called without an explicit initial state.
        context_vector, attention_weights = self.attention(features, hidden)

        # (batch_size, 1, embedding_dim) after the embedding lookup.
        embedded = self.embedding(x)

        # Prepend the context vector along the last axis:
        # (batch_size, 1, embedding_dim + hidden_size).
        gru_input = tf.concat(
            [tf.expand_dims(context_vector, 1), embedded], axis=-1
        )

        output, state = self.gru(gru_input)

        # (batch_size, max_length, hidden_size)
        projected = self.fc1(output)

        # Collapse the first two axes: (batch_size * max_length, hidden_size).
        flattened = tf.reshape(projected, (-1, projected.shape[2]))

        # (batch_size * max_length, vocab)
        logits = self.fc2(flattened)

        return logits, state, attention_weights

    def reset_state(self, batch_size):
        """Return an all-zero tensor of shape ``(batch_size, self.units)``."""
        state_shape = (batch_size, self.units)
        return tf.zeros(state_shape)

Finally you instantiate the Encoder and Decoder classes as you normally would:

# The constructors read the saved .npy weight files, so the
# encoder_layer_weights/ and decoder_layer_weights/ folders must be present.
encoder = CNN_Encoder(embedding_dim=embedding_dim)
decoder = RNN_Decoder(
    embedding_dim=embedding_dim,
    units=units,
    vocab_size=vocab_size,
)
like image 121
Piyush Singh Avatar answered Oct 20 '25 15:10

Piyush Singh