
Keyboard interrupt tensorflow run and save at that point

Is there a way to stop a TensorFlow session with a keyboard interrupt and save the model at that point? I currently leave the session running overnight, but I need to stop it to free memory for using the PC during the day. Each epoch gets slower as training progresses, so I sometimes have to wait hours for the next scheduled save in the program. I'd like to be able to break into the run at any time and save from that point, but I can't even find out whether that's possible. I would appreciate a pointer.

d8sconz asked Dec 08 '17 01:12





1 Answer

One option is to subclass tf.Session and override its __exit__ method so that it saves the current state when a KeyboardInterrupt passes through. This only works when the session is used as a context manager, that is, inside a with block.
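To see why the with block matters, here is a minimal TensorFlow-free sketch of the mechanism. FakeSession, its save() method, and run() are hypothetical stand-ins for illustration only; in the real case the save would be a saver.save(...) call.

```python
# TensorFlow-free sketch: FakeSession and save() are hypothetical stand-ins.
class FakeSession:
    def __init__(self):
        self.saved = False

    def save(self):
        # Stand-in for saver.save(sess, path).
        self.saved = True

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Python passes the type of the in-flight exception here, or None
        # when the with block finished normally.
        if exc_type is KeyboardInterrupt:
            self.save()
        # Returning False lets the exception keep propagating.
        return False


def run(session, interrupt):
    """Run a fake training block; return True if it was interrupted."""
    try:
        with session:
            if interrupt:
                raise KeyboardInterrupt  # simulates Ctrl+C mid-training
    except KeyboardInterrupt:
        return True
    return False


interrupted = FakeSession()
completed = FakeSession()
print(run(interrupted, interrupt=True), interrupted.saved)
print(run(completed, interrupt=False), completed.saved)
```

Without the with statement, Python never calls __exit__, so the interrupt would end the program without triggering the save.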

Here is the subclass:

import tensorflow as tf

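class SessionWithExitSave(tf.Session):
    """A tf.Session that saves a checkpoint when Ctrl+C ends its with block."""

    def __init__(self, *args, saver=None, exit_save_path=None, **kwargs):
        self.saver = saver                    # tf.train.Saver used for the exit save
        self.exit_save_path = exit_save_path  # checkpoint prefix for the exit save
        super().__init__(*args, **kwargs)

    def __exit__(self, exc_type, exc_value, exc_tb):
        try:
            # Save only when a KeyboardInterrupt is passing through the with block.
            if exc_type is KeyboardInterrupt and self.saver:
                self.saver.save(self, self.exit_save_path)
                print('Output saved to: "{}./*"'.format(self.exit_save_path))
        finally:
            # Always close the session, even if the save itself fails.
            super().__exit__(exc_type, exc_value, exc_tb)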

Example usage, adapted from the TensorFlow MNIST walkthrough:

import tensorflow as tf
import datetime as dt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('U:/mnist/', one_hot=True)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b
# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

# Build the evaluation ops once, before the session starts. Creating them
# inside the epoch loop would add new nodes to the graph on every epoch.
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()

with SessionWithExitSave(
        saver=saver,
        exit_save_path='./tf-saves/_lastest.ckpt') as sess:
    sess.run(tf.global_variables_initializer())
    total_epochs = 50
    for epoch in range(1, total_epochs+1):
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
        # Test trained model
        print(f'Epoch {epoch} of {total_epochs} :: accuracy = ', end='')
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))
        # Save a timestamped checkpoint at the end of every epoch
        save_time = dt.datetime.now().strftime('%Y%m%d-%H.%M.%S')
        saver.save(sess, f'./tf-saves/mnist-{save_time}.ckpt')

I let this run for 10 epochs before sending an interrupt signal from the keyboard. Here is the output:

Epoch 1 of 50 :: accuracy = 0.9169
Epoch 2 of 50 :: accuracy = 0.919
Epoch 3 of 50 :: accuracy = 0.9205
Epoch 4 of 50 :: accuracy = 0.9221
Epoch 5 of 50 :: accuracy = 0.92
Epoch 6 of 50 :: accuracy = 0.9229
Epoch 7 of 50 :: accuracy = 0.9234
Epoch 8 of 50 :: accuracy = 0.9234
Epoch 9 of 50 :: accuracy = 0.9252
Epoch 10 of 50 :: accuracy = 0.9248
Output saved to: "./tf-saves/_lastest.ckpt./*"
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
...
--> 768   elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
    769     return item[1]._is_present_in_parent
    770   else:
KeyboardInterrupt:
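As the traceback shows, the interrupt still propagates after the save. If a quiet exit is preferred, one possible variation (not part of the subclass approach above) is to catch KeyboardInterrupt around the training loop and save there. The sketch below is TensorFlow-free: train, train_one_batch, and save_checkpoint are hypothetical names that mark where sess.run(...) and saver.save(...) would go.

```python
# Hypothetical stand-ins: train_one_batch and save_checkpoint are not
# TensorFlow APIs; they mark where sess.run(...) and saver.save(...) would go.
def train(num_batches, train_one_batch, save_checkpoint):
    """Train until done or interrupted; return the number of batches completed."""
    completed = 0
    try:
        for step in range(num_batches):
            train_one_batch(step)
            completed += 1
    except KeyboardInterrupt:
        # Ctrl+C lands here; save once, then fall through to a normal return.
        save_checkpoint(completed)
    return completed


saved_at = []


def fake_batch(step):
    if step == 3:
        raise KeyboardInterrupt  # simulates Ctrl+C during the fourth batch


done = train(10, fake_batch, saved_at.append)
print(done, saved_at)
```

Because the exception is handled, no traceback is printed and any code after the loop still runs.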

And indeed, the save directory contains the periodic saves (only the most recent ones, since tf.train.Saver keeps the last five checkpoints by default) along with the save triggered by the keyboard interrupt.

import os

os.listdir('./tf-saves/')
# returns:
['checkpoint',
 'mnist-20171207-23.05.18.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.18.ckpt.index',
 'mnist-20171207-23.05.18.ckpt.meta',
 'mnist-20171207-23.05.22.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.22.ckpt.index',
 'mnist-20171207-23.05.22.ckpt.meta',
 'mnist-20171207-23.05.26.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.26.ckpt.index',
 'mnist-20171207-23.05.26.ckpt.meta',
 'mnist-20171207-23.05.31.ckpt.data-00000-of-00001',
 'mnist-20171207-23.05.31.ckpt.index',
 '_lastest.ckpt.data-00000-of-00001',
 '_lastest.ckpt.index',
 '_lastest.ckpt.meta']
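Each checkpoint is a set of files sharing one prefix, and that prefix (for example './tf-saves/_lastest.ckpt') is what gets passed to saver.restore later. As a rough sketch, the listing can be grouped by prefix to see which parts exist for each checkpoint. The file names are copied from the listing above; checkpoint_parts is a hypothetical helper, not a TensorFlow function.

```python
from collections import defaultdict

# File names copied from the directory listing above.
files = [
    'checkpoint',
    'mnist-20171207-23.05.18.ckpt.data-00000-of-00001',
    'mnist-20171207-23.05.18.ckpt.index',
    'mnist-20171207-23.05.18.ckpt.meta',
    'mnist-20171207-23.05.22.ckpt.data-00000-of-00001',
    'mnist-20171207-23.05.22.ckpt.index',
    'mnist-20171207-23.05.22.ckpt.meta',
    'mnist-20171207-23.05.26.ckpt.data-00000-of-00001',
    'mnist-20171207-23.05.26.ckpt.index',
    'mnist-20171207-23.05.26.ckpt.meta',
    'mnist-20171207-23.05.31.ckpt.data-00000-of-00001',
    'mnist-20171207-23.05.31.ckpt.index',
    '_lastest.ckpt.data-00000-of-00001',
    '_lastest.ckpt.index',
    '_lastest.ckpt.meta',
]


def checkpoint_parts(filenames):
    """Map each checkpoint prefix to the set of file kinds found for it."""
    parts = defaultdict(set)
    for name in filenames:
        prefix, sep, suffix = name.partition('.ckpt.')
        if not sep:
            continue  # skips the 'checkpoint' bookkeeping file
        # 'data-00000-of-00001' becomes 'data'; 'index' and 'meta' are unchanged.
        parts[prefix + '.ckpt'].add(suffix.split('-')[0])
    return dict(parts)


for prefix, kinds in checkpoint_parts(files).items():
    print(prefix, sorted(kinds))
```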
James answered Oct 11 '22 03:10
