Saving the state of the AdaGrad algorithm in Tensorflow

Tags:

tensorflow

I am trying to train a word2vec model and want to use the embeddings for another application. Since more data might arrive later, and training is slow on my machine, I would like my script to be able to stop and resume training later.

To do this, I created a saver:

saver = tf.train.Saver({"embeddings": embeddings, "embeddings_softmax_weights": softmax_weights, "embeddings_softmax_biases": softmax_biases})

I save the embeddings and the softmax weights and biases so I can resume training later. (I assume this is the correct way, but please correct me if I'm wrong.)

Unfortunately, when resuming training with this script, the average loss seems to go up again.

My idea is that this can be attributed to the AdaGradOptimizer I'm using. Initially its accumulator of squared gradients is presumably all zeros, whereas after my training it will be filled (leading to a lower effective learning rate).
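The effect described above can be sketched with plain NumPy (illustrative numbers, not TensorFlow internals): AdaGrad divides each step by the square root of the accumulated squared gradients, so a freshly zeroed accumulator yields much larger steps than a long-trained one.

```python
import numpy as np

# Toy AdaGrad update for a single parameter:
#   acc += grad**2;  step = lr * grad / sqrt(acc + eps)
lr, eps, grad = 0.5, 1e-8, 1.0

acc_fresh = 0.0      # accumulator lost on restart
acc_trained = 100.0  # accumulator after long training

step_fresh = lr * grad / np.sqrt(acc_fresh + grad**2 + eps)
step_trained = lr * grad / np.sqrt(acc_trained + grad**2 + eps)
print(step_fresh, step_trained)  # roughly 0.5 vs 0.05
```

A restart that restores the weights but resets the accumulator would therefore take much larger steps than the run being resumed, which is consistent with the loss going up.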

Is there a way to save the optimizer state to resume learning later?

asked Nov 11 '16 by rmeertens

1 Answer

While TensorFlow seems to complain when you attempt to serialize an optimizer object directly (e.g. via tf.add_to_collection("optimizers", optimizer) followed by a call to tf.train.Saver().save()), you can save and restore the training update operation that is derived from the optimizer:

# init
if not load_model:
    # Fresh run: build the optimizer and stash the training op in a
    # named collection so it can be recovered from the meta graph later.
    optimizer = tf.train.AdamOptimizer(1e-4)
    train_step = optimizer.minimize(loss)
    tf.add_to_collection("train_step", train_step)
else:
    # Resumed run: rebuild the graph from the meta file, restore all
    # saved variables, and fetch the training op from the collection.
    saver = tf.train.import_meta_graph(modelfile + '.meta')
    saver.restore(sess, tf.train.latest_checkpoint('./'))
    train_step = tf.get_collection("train_step")[0]

# training loop
while training:
    if iteration % save_interval == 0:
        saver = tf.train.Saver()  # no var list: saves every variable
        save_path = saver.save(sess, filepath)

I do not know of a way to get or set the parameters of an existing optimizer, so I cannot verify directly that the optimizer's internal state was restored. However, training resumes with loss and accuracy comparable to when the snapshot was created. I would also recommend the parameterless call to Saver() so that state variables not specifically mentioned are still saved, although this might not be strictly necessary.

You may also wish to save the iteration or epoch number for later restoring, as detailed in this example: http://www.seaandsailor.com/tensorflow-checkpointing.html

answered Oct 10 '22 by pygosceles