
Tensorflow Keras Copy Weights From One Model to Another

Using Keras from TensorFlow 1.4.1, how does one copy weights from one model to another?

As some background, I'm trying to implement a deep Q-network (DQN) for Atari games, following the DQN publication by DeepMind. My understanding is that the implementation uses two networks, Q and Q'. The weights of Q are trained using gradient descent, and those weights are then copied periodically to Q'.
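The periodic copy described above can be sketched without any deep learning library. This is only an illustration of the schedule, not the DeepMind implementation: `TinyModel`, `SYNC_EVERY`, and the fake "gradient step" are hypothetical stand-ins, and only the method names `get_weights` and `set_weights` are borrowed from Keras.

```python
import copy

class TinyModel:
    """Hypothetical stand-in that mimics Keras-style get_weights/set_weights."""

    def __init__(self, weights):
        self._weights = copy.deepcopy(weights)

    def get_weights(self):
        # Return a copy so callers cannot mutate the model's state by accident.
        return copy.deepcopy(self._weights)

    def set_weights(self, weights):
        self._weights = copy.deepcopy(weights)

SYNC_EVERY = 3  # hypothetical number of training steps between copies

def train(q, q_target, steps):
    """Update Q every step and copy its weights to Q' every SYNC_EVERY steps."""
    sync_steps = []
    for step in range(1, steps + 1):
        # Placeholder for a real gradient-descent step: nudge every weight of Q.
        q.set_weights([[v + 0.1 for v in w] for w in q.get_weights()])
        if step % SYNC_EVERY == 0:
            q_target.set_weights(q.get_weights())
            sync_steps.append(step)
    return sync_steps

q = TinyModel([[0.0, 0.0], [0.0]])
q_target = TinyModel(q.get_weights())
synced_at = train(q, q_target, steps=7)
```

After seven steps the copy has happened at steps 3 and 6, so Q' lags one step behind Q, which is the intended behaviour between copies.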

Here's how I build Q and Q':

import tensorflow as tf

ACT_SIZE   = 4
LEARN_RATE = 0.0025
OBS_SIZE   = 128

def buildModel():
  model = tf.keras.models.Sequential()

  # input_shape must be a tuple, excluding the batch dimension.
  model.add(tf.keras.layers.Lambda(lambda x: x / 255.0, input_shape=(OBS_SIZE,)))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(128, activation="relu"))
  model.add(tf.keras.layers.Dense(ACT_SIZE, activation="linear"))
  opt = tf.keras.optimizers.RMSprop(lr=LEARN_RATE)

  model.compile(loss="mean_squared_error", optimizer=opt)

  return model

I call that twice to get Q and Q'.

I have an updateTargetModel method below that is my attempt at copying weights. The code runs fine, but my overall DQN implementation is failing. I'm really just trying to verify if this is a valid way of copying weights from one network to another.

def updateTargetModel(model, targetModel):
  modelWeights       = model.trainable_weights
  targetModelWeights = targetModel.trainable_weights

  for i in range(len(targetModelWeights)):
    targetModelWeights[i].assign(modelWeights[i])

There's another question that discusses saving and loading weights to and from disk (Tensorflow Copy Weights Issue), but it has no accepted answer. There is also a question about loading weights from individual layers (Copying weights from one Conv2D layer to another), but I want to copy the entire model's weights.

benbotto asked Jan 31 '18 17:01


1 Answer

Actually, what you've done is much more than simply copying weights. You made the two models identical at all times: every time you update one model, the second one is updated too, because both models share the same weight variables.

If you just want to copy the weights, the simplest way is this call:

target_model.set_weights(model.get_weights())  
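
A minimal check of that call, assuming a recent TensorFlow 2.x install rather than the 1.4.1 from the question. The `build_model` helper and its layer sizes are illustrative, loosely based on the question's `buildModel`. `get_weights()` returns a list of NumPy arrays, so after the copy the two models hold equal but separate values.

```python
import numpy as np
import tensorflow as tf

def build_model(obs_size=128, act_size=4):
    # Illustrative network; sizes follow the question's OBS_SIZE and ACT_SIZE.
    return tf.keras.Sequential([
        tf.keras.Input(shape=(obs_size,)),
        tf.keras.layers.Dense(128, activation="relu"),
        tf.keras.layers.Dense(act_size, activation="linear"),
    ])

model = build_model()
target_model = build_model()

# Copy every weight array from model into target_model.
target_model.set_weights(model.get_weights())
copied = all(
    np.array_equal(a, b)
    for a, b in zip(model.get_weights(), target_model.get_weights())
)

# Change only the source model; the target keeps the values it was given.
model.set_weights([w + 1.0 for w in model.get_weights()])
independent = not any(
    np.array_equal(a, b)
    for a, b in zip(model.get_weights(), target_model.get_weights())
)
```

Here `copied` confirms the arrays match right after the copy, and `independent` confirms that a later change to `model` does not leak into `target_model`.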
Marcin Możejko answered Sep 21 '22 13:09