I'm trying to restart model training in TensorFlow, picking up where it left off. I'd like to use the recently added (0.12+, I think) import_meta_graph() so as not to reconstruct the graph.
I've seen solutions for this, e.g. Tensorflow: How to save/restore a model?, but I run into issues with AdamOptimizer; specifically, I get a ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used error. This can be fixed by running the initializer again, but then my trained model values are cleared!
There are other answers and some full examples out there, but they always seem older, so they don't include the newer import_meta_graph() approach, or they don't have a non-tensor optimizer. The closest question I could find is tensorflow: saving and restoring session, but there is no final, clear-cut solution there, and the example is pretty complicated.
Ideally I'd like a simple runnable example that starts from scratch, stops, then picks up again. I have something that works (below), but I also wonder whether I'm missing something. Surely I'm not the only one doing this?
To save and restore your variables, all you need to do is create a tf.train.Saver() at the end of your graph construction and call its save method. This will create three files (data, index, meta), suffixed with the step at which you saved your model.
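A minimal sketch of that, assuming a throwaway variable and an arbitrary /tmp/tiny_model.ckpt path:

import tensorflow as tf

v = tf.Variable(tf.zeros([3]), name="v")
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Passing global_step suffixes the files with the step number, producing
    # e.g. tiny_model.ckpt-100.data-00000-of-00001, tiny_model.ckpt-100.index,
    # and tiny_model.ckpt-100.meta
    saver.save(sess, "/tmp/tiny_model.ckpt", global_step=100)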
To continue training a loaded Keras model with checkpoints, simply rerun model.fit with the callback still passed. This, however, overwrites the currently saved best model, so make sure to change the checkpoint file path if that is undesired.
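As a sketch of that Keras workflow (the tiny model, the random data, and the best_model_run2.h5 filename are made up for illustration):

import numpy as np
import tensorflow as tf

x = np.random.rand(100, 10).astype("float32")
y = np.random.rand(100, 1).astype("float32")

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(10,))])
model.compile(optimizer="adam", loss="mse")

# A new filepath keeps the first run's best checkpoint from being overwritten
ckpt = tf.keras.callbacks.ModelCheckpoint("best_model_run2.h5",
                                          save_best_only=True)
model.fit(x, y, epochs=5, validation_split=0.2, callbacks=[ckpt])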
Model restoring can also be done using tf.saved_model.loader, which restores the saved variables, signatures, and assets in the scope of a session.
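A minimal sketch of that loading call, assuming export_dir points at an existing SavedModel exported with the serving tag:

import tensorflow as tf

export_dir = "/tmp/my_saved_model"  # assumed path to an existing SavedModel

with tf.Session(graph=tf.Graph()) as sess:
    # Restores the graph, variables, signatures, and assets into this session
    tf.saved_model.loader.load(sess,
                               [tf.saved_model.tag_constants.SERVING],
                               export_dir)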
Checkpoint file: a binary file that contains the values of all the saved variables: weights, biases, and any other variables in the graph, including optimizer state such as Adam's moment accumulators. These files share the .ckpt prefix (in newer TensorFlow versions the checkpoint is split into .data and .index files under that prefix).
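You can inspect what a checkpoint actually stores with tf.train.NewCheckpointReader; the path below is a placeholder for your own checkpoint prefix:

import tensorflow as tf

# Point this at your own checkpoint prefix (no .data/.index extension)
reader = tf.train.NewCheckpointReader("/tmp/tiny_model.ckpt-100")
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)  # variables plus any optimizer slot variables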
Here is what I came up with from reading the docs, other similar solutions, and trial and error. It's a simple autoencoder on random data. If run, then run again, it will continue from where it left off (i.e. the cost function on the first run goes from ~0.5 -> 0.3; the second run starts at ~0.3). Unless I missed something, all of the saving, constructors, model building, and add_to_collection calls are needed, and in a precise order, but there may be a simpler way.
And yes, loading the graph with import_meta_graph isn't really needed here, since the code is right above, but it is what I want in my actual application.
from __future__ import print_function

import os
import math
import numpy as np
import tensorflow as tf

output_dir = "/root/Data/temp"
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")

input_length = 10
encoded_length = 3
learning_rate = 0.001
n_epochs = 10
n_batches = 10

if not os.path.exists(model_checkpoint_file_base + ".meta"):
    print("Making new")
    brand_new = True

    # First run: build the graph from scratch
    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in")
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length],
                                          -1.0 / math.sqrt(input_length),
                                          1.0 / math.sqrt(input_length)),
                        name="W_enc")
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc")
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded")

    W_dec = tf.transpose(W_enc, name="W_dec")  # tied decoder weights
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec")
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded")

    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")
    saver = tf.train.Saver()
else:
    print("Reloading existing")
    brand_new = False

    # Later runs: rebuild the graph from the .meta file instead of
    # re-running the construction code above
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")
    g = tf.get_default_graph()
    x_in = g.get_tensor_by_name("x_in:0")
    cost = g.get_tensor_by_name("cost:0")

sess = tf.Session()

if brand_new:
    # minimize() adds Adam's slot variables (e.g. W_enc/Adam) to the graph,
    # so it must run before the initializer
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    init = tf.global_variables_initializer()
    sess.run(init)
    # Stash the train op so the restoring run can fetch it without calling
    # minimize() again, which would raise the duplicate-name ValueError
    tf.add_to_collection("optimizer", optimizer)
else:
    # Restore variable values (weights and Adam state) into the imported graph
    saver.restore(sess, model_checkpoint_file_base)
    optimizer = tf.get_collection("optimizer")[0]

for epoch_i in range(n_epochs):
    for batch_i in range(n_batches):
        batch = np.random.rand(50, input_length)  # fresh random batch
        _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch})
        print("batch_cost:", curr_cost)

# A fresh Saver here picks up the Adam slot variables, which did not exist
# yet when the Saver above was constructed
save_path = tf.train.Saver().save(sess, model_checkpoint_file_base)