
TensorFlow: how to safely terminate training manually (KeyboardInterrupt)

I want to add functionality to my code so that, if I decide to stop it at any point, it safely ends training and saves the variables. I searched for a better solution, but catching a KeyboardInterrupt exception seems to be my best option.

Would it, however, be safe? More specifically, would the following code work:

import tensorflow as tf

with tf.Session() as sess:
    try:
        for i in range(FLAGS.max_steps):
            sess.run(train_op, feed_dict=some_feed_dictionary)
            # Some other summary writing and evaluative operations
    except KeyboardInterrupt:
        print("Manual interrupt occurred.")

    print('Done training for {} steps'.format(global_steps))
    save_path = saver.save(sess, 'Standard CNN', global_step=global_steps, write_meta_graph=False)

Or is it unsafe, and could it produce corrupted save files, given that the KeyboardInterrupt can arrive in the middle of any TensorFlow operation? Is there a reliable way to do this?
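One way to make sure the interrupt never lands in the middle of a step is to not let Ctrl+C raise an exception at all while training: install a SIGINT handler that only sets a flag, and check that flag between steps. The sketch below assumes a plain Python loop; `StopRequested`, `train`, `step_fn` and `save_fn` are hypothetical names, with `step_fn` standing in for `sess.run(train_op, ...)` and `save_fn` for `saver.save(...)`. It is an alternative pattern, not a TensorFlow API.

```python
import signal


class StopRequested:
    """Context manager that turns Ctrl+C (SIGINT) into a flag.

    While active, SIGINT no longer raises KeyboardInterrupt; it only sets
    self.stop, which the training loop checks between steps.
    """

    def __init__(self):
        self.stop = False
        self._previous = None

    def _handler(self, signum, frame):
        # Runs in the main thread between bytecode instructions; just record the request.
        self.stop = True

    def __enter__(self):
        # signal.signal() may only be called from the main thread.
        self._previous = signal.signal(signal.SIGINT, self._handler)
        return self

    def __exit__(self, exc_type, exc, tb):
        # Restore the handler that was active before (normally default_int_handler).
        restore = self._previous if self._previous is not None else signal.SIG_DFL
        signal.signal(signal.SIGINT, restore)
        return False


def train(max_steps, step_fn, save_fn):
    """Run up to max_steps steps; on Ctrl+C, finish the current step, save, and stop."""
    steps_done = 0
    with StopRequested() as guard:
        for i in range(max_steps):
            step_fn(i)  # hypothetical stand-in for sess.run(train_op, ...)
            steps_done = i + 1
            if guard.stop:
                print("Manual interrupt occurred.")
                break
        # Still inside the guard, so another Ctrl+C cannot cut the save short.
        save_fn(steps_done)  # hypothetical stand-in for saver.save(...)
    print("Done training for {} steps".format(steps_done))
    return steps_done
```

Because Python runs signal handlers between bytecode instructions, a long `sess.run` call still completes before the flag is checked, so stopping happens only at step boundaries. The trade-off is that a second Ctrl+C will not force-quit while the guard is active.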

Bruno KM asked Nov 08 '22 19:11


1 Answer

I use something very similar all the time: I catch KeyboardInterrupt during training. The only difference is that I save after each sess.run step (or every few steps), and I have never had problems with that.

I don't know the answer to "is it unsafe?", but my approach means I never had to ask it: the except block does not save anything, so an interrupt only costs the steps since the last completed periodic save.
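If the remaining worry is an interrupt landing while a file is being written, a generic way to rule out a half-written file is to write to a temporary file in the same directory and then rename it over the target with `os.replace`, which Python documents as atomic on POSIX. The sketch below uses `pickle` and a plain dict as a stand-in for model variables; `atomic_save` and `load` are hypothetical helpers, and this makes no claim about how `tf.train.Saver` writes its own checkpoint files.

```python
import os
import pickle
import tempfile


def atomic_save(obj, path):
    """Pickle obj to path so that path holds either the old or the new file, never a partial one."""
    directory = os.path.dirname(os.path.abspath(path))
    # Temp file in the same directory, so os.replace is a same-filesystem rename.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".tmp")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push the bytes to disk before the rename
        os.replace(tmp_path, path)  # atomic rename on POSIX
    except BaseException:
        # BaseException also covers KeyboardInterrupt: drop the partial temp file, re-raise.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def load(path):
    """Read back a checkpoint written by atomic_save."""
    with open(path, "rb") as f:
        return pickle.load(f)
```

If Ctrl+C arrives during `pickle.dump`, the previous `path` is left untouched and only the temporary file is discarded.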

In your code that would look like this:

with tf.Session() as sess:
    try:
        for i in range(FLAGS.max_steps):
            sess.run(train_op, feed_dict=some_feed_dictionary)
            # Some other summary writing and evaluative operations
            if i % save_steps == 0:
                # Save every save_steps steps, so an interrupt only loses the steps since the last checkpoint
                save_path = saver.save(sess, 'Standard CNN', global_step=global_steps, write_meta_graph=False)
    except KeyboardInterrupt:
        print("Manual interrupt occurred.")

    print('Done training for {} steps'.format(global_steps))

Here, save_steps is the number of training steps between saves.
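With `i % save_steps == 0` the first save happens right after the first step, and the steps run after the last periodic save are lost on interrupt. A possible variant, sketched below under the assumption of a plain Python loop, counts completed steps and adds one final save after the loop whether it finished or was interrupted. `train_with_checkpoints`, `step_fn` and `save_fn` are hypothetical names standing in for the loop, `sess.run(train_op, ...)` and `saver.save(sess, ..., global_step=...)`.

```python
def train_with_checkpoints(max_steps, save_steps, step_fn, save_fn):
    """Save every save_steps completed steps, plus once more at the end (even after Ctrl+C)."""
    steps_done = 0
    last_saved = None
    try:
        for i in range(max_steps):
            step_fn(i)  # hypothetical stand-in for sess.run(train_op, ...)
            steps_done = i + 1
            if steps_done % save_steps == 0:
                save_fn(steps_done)  # hypothetical stand-in for saver.save(...)
                last_saved = steps_done
    except KeyboardInterrupt:
        print("Manual interrupt occurred.")
    # Final save, skipped when the last periodic save already covers it.
    if steps_done != last_saved:
        save_fn(steps_done)
    print("Done training for {} steps".format(steps_done))
    return steps_done
```

If the interrupt hits during a periodic save, `last_saved` is not updated, so the final save runs again. A second Ctrl+C during that final save is not caught and will propagate.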

Ofer Sadan answered Nov 15 '22 06:11