Using Keras from TensorFlow 1.4.1, how does one copy weights from one model to another?
As some background, I'm trying to implement a deep Q-network (DQN) for Atari games, following DeepMind's DQN publication. My understanding is that the implementation uses two networks, Q and Q'. The weights of Q are trained using gradient descent, and then the weights are periodically copied to Q'.
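The periodic update scheme can be sketched without any framework. This is a toy under stated assumptions: plain Python lists stand in for network weights, and `ToyNet`, `train_step`, and `TARGET_SYNC_EVERY` are hypothetical names, not part of Keras or DeepMind's code. `get_weights`/`set_weights` here only mimic the idea of copying values from Q into Q' every few steps.

```python
import copy

TARGET_SYNC_EVERY = 3  # hypothetical sync period: copy Q into Q' every 3 steps


class ToyNet:
    """Hypothetical stand-in for a network: just a list of weight values."""

    def __init__(self, weights):
        self.weights = list(weights)

    def get_weights(self):
        # Return an independent copy of the current values
        return copy.deepcopy(self.weights)

    def set_weights(self, weights):
        # Overwrite this net's values with a copy of the given ones
        self.weights = copy.deepcopy(weights)


def train_step(net):
    """Hypothetical gradient step: nudge every weight by +1."""
    net.weights = [w + 1.0 for w in net.weights]


q = ToyNet([0.0, 0.0])
q_target = ToyNet([0.0, 0.0])
q_target.set_weights(q.get_weights())  # start Q' identical to Q

history = []
for step in range(1, 7):
    train_step(q)  # only Q is trained
    if step % TARGET_SYNC_EVERY == 0:
        q_target.set_weights(q.get_weights())  # periodic copy Q -> Q'
    history.append((step, q.weights[0], q_target.weights[0]))

for row in history:
    print(row)
```

Between syncs Q' stays frozen while Q keeps changing, which is what gives the DQN update a stable target.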
Here's how I build Q and Q':
```python
import tensorflow as tf

ACT_SIZE = 4
LEARN_RATE = 0.0025
OBS_SIZE = 128

def buildModel():
    model = tf.keras.models.Sequential()
    # Scale raw observation values (0-255) to [0, 1]; input_shape must be a tuple
    model.add(tf.keras.layers.Lambda(lambda x: x / 255.0, input_shape=(OBS_SIZE,)))
    model.add(tf.keras.layers.Dense(128, activation="relu"))
    model.add(tf.keras.layers.Dense(128, activation="relu"))
    # One linear Q-value output per action
    model.add(tf.keras.layers.Dense(ACT_SIZE, activation="linear"))
    opt = tf.keras.optimizers.RMSprop(lr=LEARN_RATE)
    model.compile(loss="mean_squared_error", optimizer=opt)
    return model
```
I call that twice to get Q and Q'.
I have an `updateTargetModel` function, shown below, that is my attempt at copying weights. The code runs fine, but my overall DQN implementation is failing. I'm really just trying to verify whether this is a valid way of copying weights from one network to another.
```python
def updateTargetModel(model, targetModel):
    modelWeights = model.trainable_weights
    targetModelWeights = targetModel.trainable_weights
    # Assign each trainable variable of Q to the matching variable of Q'
    for i in range(len(targetModelWeights)):
        targetModelWeights[i].assign(modelWeights[i])
```
There's another question here that discusses saving and loading weights to and from disk (Tensorflow Copy Weights Issue), but it has no accepted answer. There is also a question about loading weights from individual layers (Copying weights from one Conv2D layer to another), but I want to copy the entire model's weights.
Actually, what you've done is much more than simply copying weights: you've made these two models identical at all times. Every time you update one model, the second one is also updated, because both models share the same weight variables.
If you just want to copy the weights, the simplest way is this command:

```python
target_model.set_weights(model.get_weights())
```
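The difference between sharing weight variables and copying their values can be shown without TensorFlow. This is a toy under stated assumptions: `ToyNet` and `gradient_step` are hypothetical, and plain Python lists stand in for weight variables. In Keras, `get_weights()` returns NumPy arrays holding the current values, so passing them to `set_weights` writes a snapshot into the target rather than linking the two models.

```python
import copy


class ToyNet:
    """Hypothetical stand-in for a model whose weights are a mutable list."""

    def __init__(self, weights):
        self.weights = weights  # keep whatever object we are given (no copy)


def gradient_step(net):
    """Hypothetical in-place update, like an optimizer mutating variables."""
    for i in range(len(net.weights)):
        net.weights[i] += 0.5


# Case 1: both nets hold the SAME weight object (shared variables).
shared = [1.0, 2.0]
q_shared = ToyNet(shared)
target_shared = ToyNet(shared)
gradient_step(q_shared)
print(target_shared.weights)  # the target moved too

# Case 2: the target receives a value copy (the get_weights/set_weights idea).
q = ToyNet([1.0, 2.0])
target = ToyNet(copy.deepcopy(q.weights))
gradient_step(q)
print(q.weights, target.weights)  # the target keeps the old snapshot
```

With shared storage, Q' can never lag behind Q, which defeats the purpose of a target network; a value copy only changes Q' when you explicitly copy again.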