 

How to configure the model.ckpt output interval of tensorflow legacy/train.py

I am trying to address an issue caused by overfitting of a model. Unfortunately I don't know how to adjust the interval at which legacy/train.py writes model.ckpt during training. Is there a way to reduce the time between each save of model.ckpt and to disable the deletion of older checkpoints? I am training small models and can afford the increased storage requirement.

asked Jan 16 '19 by Artur Müller Romanov


1 Answer

For save intervals and number of checkpoints to keep, have a look here: https://www.tensorflow.org/api_docs/python/tf/train/Saver

The two relevant constructor arguments from the link above are:
- max_to_keep
- keep_checkpoint_every_n_hours

Additionally, optional arguments to the Saver() constructor let you control the proliferation of checkpoint files on disk:

max_to_keep indicates the maximum number of recent checkpoint files to keep. As new files are created, older files are deleted. If None or 0, no checkpoints are deleted from the filesystem but only the last one is kept in the checkpoint file. Defaults to 5 (that is, the 5 most recent checkpoint files are kept.)

keep_checkpoint_every_n_hours: In addition to keeping the most recent max_to_keep checkpoint files, you might want to keep one checkpoint file for every N hours of training. This can be useful if you want to later analyze how a model progressed during a long training session. For example, passing keep_checkpoint_every_n_hours=2 ensures that you keep one checkpoint file for every 2 hours of training. The default value of 10,000 hours effectively disables the feature.
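To see how these two arguments work together, here is a minimal, self-contained sketch (this is not taken from legacy/train.py; the toy variable, train op and the training/model.ckpt path are placeholders for illustration):

import tensorflow as tf

# Toy variable and "training" op so the snippet runs on its own.
x = tf.Variable(0.0)
train_op = tf.assign_add(x, 1.0)

# max_to_keep=None (or 0) stops the Saver from deleting old checkpoint files;
# keep_checkpoint_every_n_hours=1 would additionally retain one per hour.
saver = tf.train.Saver(max_to_keep=None, keep_checkpoint_every_n_hours=1)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    for step in range(1, 11):
        sess.run(train_op)
        # Save as often as you like; with max_to_keep=None every file stays on disk.
        saver.save(sess, 'training/model.ckpt', global_step=step)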

I believe you can set this in the training config if you use one. Check out the trainer.py file in the same legacy directory. Around line 375, it references keep_checkpoint_every_n_hours:

# Save checkpoints regularly.
keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
saver = tf.train.Saver(keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours)

What it doesn't reference is max_to_keep, which may need to be added to that script (see the sketch below). That said, in closing: it's difficult to be certain without all the information, but I can't help thinking you are going about this the wrong way. Collecting and reviewing every checkpoint doesn't seem to be the right way to deal with overfitting. Run TensorBoard and check the results of your training there. Additionally, evaluating the model on evaluation data will also provide a great deal of insight into what your model is doing.
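If you do decide to edit trainer.py, a sketch of that change (simply adding max_to_keep to the existing Saver construction; this is not the upstream code itself) could look like:

# Save checkpoints regularly, and keep all of them on disk.
keep_checkpoint_every_n_hours = train_config.keep_checkpoint_every_n_hours
saver = tf.train.Saver(
    keep_checkpoint_every_n_hours=keep_checkpoint_every_n_hours,
    max_to_keep=None)  # None or 0: old checkpoint files are not deleted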

All the best with your training!

answered Oct 21 '22 by IamSierraCharlie