I ran a training job with TensorFlow and got the following curve for the loss on the validation set. The net starts to overfit after the 6000th iteration, so I'd like to get the model from before the overfitting starts.
My training code looks something like this:
train_step = ......
summary = tf.scalar_summary(l1_loss.op.name, l1_loss)
summary_writer = tf.train.SummaryWriter("checkpoint", sess.graph)
saver = tf.train.Saver()
for i in xrange(20000):
    batch = get_next_batch(batch_size)
    sess.run(train_step, feed_dict={x: batch.x, y: batch.y})
    if (i + 1) % 100 == 0:
        saver.save(sess, "checkpoint/net", global_step=i + 1)
        summary_str = sess.run(summary, feed_dict=validation_feed_dict)
        summary_writer.add_summary(summary_str, i + 1)
        summary_writer.flush()
After training finishes, only five checkpoints are saved (19600, 19700, 19800, 19900, 20000). Is there any way to make TensorFlow save checkpoints according to the validation error?
P.S. I know that tf.train.Saver has a max_to_keep argument, which in principle could save all the checkpoints. But that's not what I want (unless it's the only option). I want the saver to keep the checkpoint with the smallest validation loss so far. Is that possible?
If save_best_only=True, it only saves when the model is considered the "best", and the latest best model according to the quantity monitored will not be overwritten. If filepath doesn't contain formatting options like {epoch}, then filepath will be overwritten by each new better model.
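For example, a minimal sketch of that behavior with Keras (the file names and monitored quantity here are my own illustration, not from the original answer):

from keras.callbacks import ModelCheckpoint

# No formatting options in the path: each new best model overwrites this file.
best_only = ModelCheckpoint("best_model.h5", monitor="val_loss",
                            save_best_only=True, mode="min")

# With {epoch} (and optionally the monitored value) in the path,
# every improvement is written to its own file instead.
per_improvement = ModelCheckpoint("model-{epoch:02d}-{val_loss:.4f}.h5",
                                  monitor="val_loss",
                                  save_best_only=True, mode="min")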
Solutions to this are to decrease your network size or to increase dropout; for example, you could try a dropout of 0.5. If your training and validation losses are about equal, then your model is underfitting: increase the size of your model (either the number of layers or the number of neurons per layer).
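As a hedged sketch of adding dropout in the question's TF 0.x style (the hidden layer name and the keep_prob placeholder are illustrative, not from the original code):

keep_prob = tf.placeholder(tf.float32)
# A dropout of 0.5 corresponds to keep_prob = 0.5 during training; feed 1.0 when evaluating.
hidden_dropped = tf.nn.dropout(hidden, keep_prob)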
If you want to save the best model during training, you have to use the ModelCheckpoint callback class. It has options to save the model weights at given points during training and lets you keep the weights from the epoch where the validation loss was at its minimum.
The ModelCheckpoint callback class allows you to define where to checkpoint the model weights, how to name the file, and under what circumstances to make a checkpoint of the model. The API allows you to specify which metric to monitor, such as loss or accuracy on the training or validation dataset.
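A hedged usage sketch, assuming a compiled Keras model and validation data; the names model, x_train, y_train, x_val, y_val, and the file name are placeholders, not from the answer:

from keras.callbacks import ModelCheckpoint

checkpoint = ModelCheckpoint("weights.best.h5", monitor="val_loss",
                             verbose=1, save_best_only=True, mode="min")

model.fit(x_train, y_train,
          validation_data=(x_val, y_val),
          epochs=50, batch_size=32,
          callbacks=[checkpoint])

The callback runs at the end of every epoch, compares the monitored metric against the best value seen so far, and only writes the file when it improves.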
You need to calculate the classification accuracy on the validation set, keep track of the best one seen so far, and only write a checkpoint once an improvement in validation accuracy has been found.
If the data set and/or model is large, then you may have to split the validation set into batches to fit the computation in memory.
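Translated into the question's own TF 0.x style loop, the idea might look like the sketch below. The answer tracks validation accuracy, but the same pattern works with the question's l1_loss (lower is better); best_validation_loss and the checkpoint path are illustrative, while l1_loss, validation_feed_dict, and get_next_batch come from the question:

saver = tf.train.Saver()
best_validation_loss = float("inf")

for i in xrange(20000):
    batch = get_next_batch(batch_size)
    sess.run(train_step, feed_dict={x: batch.x, y: batch.y})
    if (i + 1) % 100 == 0:
        # Evaluate the loss on the validation set.
        val_loss = sess.run(l1_loss, feed_dict=validation_feed_dict)
        if val_loss < best_validation_loss:
            # Only save when the validation loss improves, so the
            # latest checkpoint on disk is always the best one so far.
            best_validation_loss = val_loss
            saver.save(sess, "checkpoint/best_net", global_step=i + 1)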
This tutorial shows exactly how to do what you want:
https://github.com/Hvass-Labs/TensorFlow-Tutorials/blob/master/04_Save_Restore.ipynb
It is also available as a short video:
https://www.youtube.com/watch?v=Lx8JUJROkh0