
Python TensorFlow: How to restart training with optimizer and import_meta_graph?

Tags: tensorflow

I'm trying to restart model training in TensorFlow by picking up where it left off. I'd like to use import_meta_graph(), which was added recently (in 0.12, I think), so that I don't have to reconstruct the graph.

I've seen solutions for this, e.g. "Tensorflow: How to save/restore a model?", but I run into issues with AdamOptimizer. Specifically, I get this error: "ValueError: cannot add op with name <my weights variable name>/Adam as that name is already used". This can be fixed by initializing, but then my model values are cleared!

There are other answers and some full examples out there, but they all seem to be older, so they either don't cover the newer import_meta_graph() approach or don't have a non-tensor optimizer. The closest question I could find is "tensorflow: saving and restoring session", but it has no clear-cut final solution and its example is pretty complicated.

Ideally I'd like a simple runnable example that starts from scratch, stops, then picks up again. I have something that works (below), but I wonder if I'm missing something. Surely I'm not the only one doing this?

Ken asked Apr 06 '17 00:04



1 Answer

Here is what I came up with from reading the docs, looking at similar solutions, and trial and error. It's a simple autoencoder on random data. If you run it and then run it again, it continues from where it left off (i.e. the cost goes from ~0.5 to ~0.3 on the first run, and the second run starts at ~0.3). Unless I missed something, all of the saving, constructors, model building, and add_to_collection calls are needed, and in this precise order, but there may be a simpler way.

And yes, loading the graph with import_meta_graph isn't really needed here since the graph-building code is right there in the same script, but it is what I want in my actual application.

from __future__ import print_function
import tensorflow as tf
import os
import math
import numpy as np

output_dir = "/root/Data/temp"
model_checkpoint_file_base = os.path.join(output_dir, "model.ckpt")

input_length = 10
encoded_length = 3
learning_rate = 0.001
n_epochs = 10
n_batches = 10
if not os.path.exists(model_checkpoint_file_base + ".meta"):
    print("Making new")
    brand_new = True

    x_in = tf.placeholder(tf.float32, [None, input_length], name="x_in")
    W_enc = tf.Variable(tf.random_uniform([input_length, encoded_length],
                                          -1.0 / math.sqrt(input_length),
                                          1.0 / math.sqrt(input_length)), name="W_enc")
    b_enc = tf.Variable(tf.zeros(encoded_length), name="b_enc")
    encoded = tf.nn.tanh(tf.matmul(x_in, W_enc) + b_enc, name="encoded")
    W_dec = tf.transpose(W_enc, name="W_dec")
    b_dec = tf.Variable(tf.zeros(input_length), name="b_dec")
    decoded = tf.nn.tanh(tf.matmul(encoded, W_dec) + b_dec, name="decoded")
    cost = tf.sqrt(tf.reduce_mean(tf.square(decoded - x_in)), name="cost")

    saver = tf.train.Saver()
else:
    print("Reloading existing")
    brand_new = False
    saver = tf.train.import_meta_graph(model_checkpoint_file_base + ".meta")
    g = tf.get_default_graph()
    x_in = g.get_tensor_by_name("x_in:0")
    cost = g.get_tensor_by_name("cost:0")


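sess = tf.Session()
if brand_new:
    # minimize() adds Adam's own variables (slots, beta powers) to the graph.
    optimizer = tf.train.AdamOptimizer(learning_rate).minimize(cost)
    sess.run(tf.global_variables_initializer())
    # Store the train op in a collection so it can be found after import_meta_graph().
    tf.add_to_collection("optimizer", optimizer)
    # Recreate the Saver now that Adam's variables exist, so checkpoints include them.
    saver = tf.train.Saver()
else:
    # Restore weights and Adam's state; no initializer runs, so nothing is cleared.
    saver.restore(sess, model_checkpoint_file_base)
    optimizer = tf.get_collection("optimizer")[0]

for epoch_i in range(n_epochs):
    for batch_i in range(n_batches):
        batch = np.random.rand(50, input_length)
        _, curr_cost = sess.run([optimizer, cost], feed_dict={x_in: batch})
        print("batch_cost:", curr_cost)
        # Reuse one Saver; calling tf.train.Saver() on every save adds new ops each time.
        save_path = saver.save(sess, model_checkpoint_file_base)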
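
For intuition on why the optimizer has to be saved and restored together with the weights, here is a framework-free sketch. It is plain Python on a one-parameter toy problem, not TensorFlow's implementation; adam_step, train, the toy loss and the JSON checkpoint file are all made up for illustration. Adam keeps running moment estimates and a step count for each variable (TensorFlow stores the moments as extra variables with names like W_enc/Adam), so a resume that reloads them should match an uninterrupted run, while a resume with a fresh optimizer should not.

```python
import json
import math
import os
import tempfile


def adam_step(w, state, grad, lr=0.1, beta1=0.9, beta2=0.999, eps=1e-8):
    """Apply one Adam update to the scalar w; return the new w and new state."""
    t = state["t"] + 1
    m = beta1 * state["m"] + (1 - beta1) * grad          # first moment
    v = beta2 * state["v"] + (1 - beta2) * grad * grad   # second moment
    m_hat = m / (1 - beta1 ** t)                         # bias correction
    v_hat = v / (1 - beta2 ** t)
    w = w - lr * m_hat / (math.sqrt(v_hat) + eps)
    return w, {"t": t, "m": m, "v": v}


def train(w, state, n_steps):
    """Minimize the toy loss (w - 3)^2, whose gradient is 2 * (w - 3)."""
    for _ in range(n_steps):
        w, state = adam_step(w, state, 2.0 * (w - 3.0))
    return w, state


fresh_state = {"t": 0, "m": 0.0, "v": 0.0}

# Reference: 20 uninterrupted steps.
w_full, _ = train(0.0, dict(fresh_state), 20)

# Train 10 steps, checkpoint the weight AND the optimizer state, then resume.
w_half, state_half = train(0.0, dict(fresh_state), 10)
ckpt_path = os.path.join(tempfile.mkdtemp(), "toy_ckpt.json")
with open(ckpt_path, "w") as f:
    json.dump({"w": w_half, "opt": state_half}, f)
with open(ckpt_path) as f:
    ckpt = json.load(f)
w_resumed, _ = train(ckpt["w"], ckpt["opt"], 10)

# Same restored weight, but with a brand-new optimizer state.
w_reset, _ = train(ckpt["w"], dict(fresh_state), 10)

print("uninterrupted:        ", w_full)
print("resumed with state:   ", w_resumed)
print("resumed, state reset: ", w_reset)
```

In this sketch w_resumed should agree with w_full, while w_reset ends up somewhere else. That is the toy analogue of restoring AdamOptimizer from the checkpoint versus constructing a new one on top of restored weights.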
Ken answered Sep 17 '22 22:09