
How to control the number of checkpoints kept by a TensorFlow Estimator?

I've noticed that the new Estimator API automatically saves checkpoints during training and automatically resumes from the last checkpoint if training was interrupted. Unfortunately, it seems to keep only the last 5 checkpoints.

Do you know how to control the number of checkpoints that are kept during training?

Piotr Czapla asked Jan 28 '23 23:01



1 Answer

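TensorFlow's tf.estimator.Estimator takes config as an optional argument, which can be a tf.estimator.RunConfig object that configures runtime settings. Its keep_checkpoint_max setting controls how many recent checkpoints are kept (the default is 5). You can change it as follows: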

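import tensorflow as tf

# Raise the maximum number of retained checkpoints from the default (5) to 25
run_config = tf.estimator.RunConfig()
run_config = run_config.replace(keep_checkpoint_max=25)

# Build your estimator; model_fn and job_dir are assumed to be defined elsewhere
# (your model function and the directory where checkpoints are written)
estimator = tf.estimator.Estimator(model_fn,
                                   model_dir=job_dir,
                                   config=run_config,
                                   params=None)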

The config parameter is available in all classes (DNNClassifier, DNNLinearCombinedClassifier, LinearClassifier, etc.) that extend tf.estimator.Estimator.
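To make the effect of keep_checkpoint_max concrete, here is a small pure-Python sketch of the retention rule as the RunConfig documentation describes it: only the most recent N checkpoints survive, and None or 0 keeps them all. It does not call TensorFlow; retained_checkpoints is a hypothetical helper written for illustration, and the model.ckpt-<step> names only imitate the Estimator's usual file naming.

```python
from collections import deque
from typing import Deque, List, Optional


def retained_checkpoints(steps: List[int], keep_max: Optional[int] = 5) -> List[str]:
    """Return the checkpoint names left on disk after saving at each step.

    Hypothetical helper, not part of TensorFlow. It mimics the documented
    rule: keep the most recent `keep_max` checkpoints; None or 0 keeps all.
    """
    # deque(maxlen=None) is unbounded, so 0 is mapped to None ("keep all").
    kept: Deque[str] = deque(maxlen=keep_max or None)
    for step in steps:
        # Appending to a full deque silently drops the oldest entry.
        kept.append(f"model.ckpt-{step}")
    return list(kept)


saves = list(range(0, 3000, 100))  # 30 saves at steps 0, 100, ..., 2900

print(retained_checkpoints(saves, keep_max=5))          # the five newest names
print(len(retained_checkpoints(saves, keep_max=25)))    # how many survive with 25
print(len(retained_checkpoints(saves, keep_max=None)))  # how many survive with None
```

In practice this means that raising keep_checkpoint_max trades disk space for the ability to go back to older checkpoints.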

Zafarullah Mahmood answered Jan 31 '23 18:01
