Is there a way to break a TensorFlow session with a keyboard interrupt and have the option to save the model at that point? I currently leave the session running overnight but need to stop it so I can free memory for PC use during the day. As the training progresses each epoch gets slower, so there are times when I may have to wait hours for the next scheduled save in the program. I'd like to be able to break into the run at any time and save from that point. I can't even find out whether that's possible. Would appreciate a pointer.
One option is to subclass the tf.Session object and define an __exit__ method that saves the current state as the keyboard interrupt passes through. This only works when the new object is used as the context manager of a with block.
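To see the mechanism in isolation, here is a minimal sketch that does not need TensorFlow at all. The class SaveOnInterrupt and the helper run_and_interrupt are hypothetical names made up for illustration; the "save" is just a flag, standing in for whatever a real session would write to disk.

```python
class SaveOnInterrupt:
    """Hypothetical stand-in for a session; records whether a save happened."""

    def __init__(self):
        self.saved = False

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_value, exc_tb):
        # exc_type is None on a clean exit, otherwise the exception class.
        if exc_type is KeyboardInterrupt:
            self.saved = True  # a real session would write a checkpoint here
        # Returning False lets the exception keep propagating to the caller.
        return False


def run_and_interrupt():
    """Simulate Ctrl+C inside the with block and report what happened."""
    ctx = SaveOnInterrupt()
    propagated = False
    try:
        with ctx:
            raise KeyboardInterrupt  # stands in for pressing Ctrl+C
    except KeyboardInterrupt:
        propagated = True
    return ctx.saved, propagated


print(run_and_interrupt())
```

The point of the sketch is that __exit__ sees the interrupt before it leaves the with block, so the save runs first and the interrupt still reaches the caller afterwards.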
Here is the subclass:
import tensorflow as tf

class SessionWithExitSave(tf.Session):
    def __init__(self, *args, saver=None, exit_save_path=None, **kwargs):
        self.saver = saver
        self.exit_save_path = exit_save_path
        super().__init__(*args, **kwargs)

    def __exit__(self, exc_type, exc_value, exc_tb):
        # Save a checkpoint only when the block is left via Ctrl+C.
        if exc_type is KeyboardInterrupt:
            if self.saver:
                self.saver.save(self, self.exit_save_path)
                print('Output saved to: "{}./*"'.format(self.exit_save_path))
        # Let tf.Session clean up; the interrupt still propagates afterwards.
        super().__exit__(exc_type, exc_value, exc_tb)
Example usage, adapted from the TensorFlow MNIST walkthrough:
import tensorflow as tf
import datetime as dt
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('U:/mnist/', one_hot=True)

x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
y = tf.matmul(x, W) + b

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(cross_entropy)

# Accuracy ops are built once, outside the loop, so the graph does not grow each epoch
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

saver = tf.train.Saver()

with SessionWithExitSave(
        saver=saver,
        exit_save_path='./tf-saves/_lastest.ckpt') as sess:
    sess.run(tf.global_variables_initializer())
    total_epochs = 50
    for epoch in range(1, total_epochs+1):
        for _ in range(1000):
            batch_xs, batch_ys = mnist.train.next_batch(100)
            sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

        # Test trained model
        print(f'Epoch {epoch} of {total_epochs} :: accuracy = ', end='')
        print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

        # Scheduled save at the end of every epoch
        save_time = dt.datetime.now().strftime('%Y%m%d-%H.%M.%S')
        saver.save(sess, f'./tf-saves/mnist-{save_time}.ckpt')
I let this run for 10 epochs before sending an interrupt signal from the keyboard. Here is the output:
Epoch 1 of 50 :: accuracy = 0.9169
Epoch 2 of 50 :: accuracy = 0.919
Epoch 3 of 50 :: accuracy = 0.9205
Epoch 4 of 50 :: accuracy = 0.9221
Epoch 5 of 50 :: accuracy = 0.92
Epoch 6 of 50 :: accuracy = 0.9229
Epoch 7 of 50 :: accuracy = 0.9234
Epoch 8 of 50 :: accuracy = 0.9234
Epoch 9 of 50 :: accuracy = 0.9252
Epoch 10 of 50 :: accuracy = 0.9248
Output saved to: "./tf-saves/_lastest.ckpt./*"
---------------------------------------------------------------------------
KeyboardInterrupt Traceback (most recent call last)
...
--> 768 elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE:
769 return item[1]._is_present_in_parent
770 else:
KeyboardInterrupt:
And indeed, I do have all of the saved files, including the save triggered by the keyboard interrupt that was sent to the system.
import os
os.listdir('./tf-saves/')
# returns:
['checkpoint',
'mnist-20171207-23.05.18.ckpt.data-00000-of-00001',
'mnist-20171207-23.05.18.ckpt.index',
'mnist-20171207-23.05.18.ckpt.meta',
'mnist-20171207-23.05.22.ckpt.data-00000-of-00001',
'mnist-20171207-23.05.22.ckpt.index',
'mnist-20171207-23.05.22.ckpt.meta',
'mnist-20171207-23.05.26.ckpt.data-00000-of-00001',
'mnist-20171207-23.05.26.ckpt.index',
'mnist-20171207-23.05.26.ckpt.meta',
'mnist-20171207-23.05.31.ckpt.data-00000-of-00001',
'mnist-20171207-23.05.31.ckpt.index',
'_lastest.ckpt.data-00000-of-00001',
'_lastest.ckpt.index',
'_lastest.ckpt.meta']
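If subclassing the session feels heavy, a different technique is to catch the interrupt around the training loop itself. The following is only a sketch under stated assumptions: train, step_fn and save_fn are hypothetical names, and in a real script step_fn and save_fn would wrap something like sess.run(train_step, ...) and saver.save(sess, path). Note that, unlike the subclass above, this version swallows the interrupt after saving instead of re-raising it.

```python
def train(num_steps, step_fn, save_fn):
    """Run step_fn up to num_steps times; call save_fn once if interrupted.

    step_fn and save_fn are hypothetical callables supplied by the caller.
    Returns the number of steps that completed.
    """
    completed = 0
    try:
        for _ in range(num_steps):
            step_fn()
            completed += 1
    except KeyboardInterrupt:
        # Ctrl+C lands here; save before leaving the loop.
        save_fn()
    return completed


def demo():
    """Simulate an interrupt during the third step."""
    calls = {"steps": 0, "saves": 0}

    def step_fn():
        if calls["steps"] == 2:
            raise KeyboardInterrupt  # stands in for pressing Ctrl+C
        calls["steps"] += 1

    def save_fn():
        calls["saves"] += 1

    completed = train(10, step_fn, save_fn)
    return completed, calls["saves"]


print(demo())
```

Either way, the save happens in the middle of whatever step was interrupted, so the checkpoint reflects the last fully applied update rather than a clean epoch boundary.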