How to checkpoint tf.data Dataset objects?

When checkpointing during training (in case of a crash, etc.), I save the graph and parameters, but it's not clear how to do the same for the new tf.data objects used for input.

Is there a straightforward way to also checkpoint these so that I can continue the current epoch, or restore the shuffle state (maybe from a seed)?

Asked Apr 09 '26 14:04 by machinaut


1 Answer

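The tf.contrib.data.make_saveable_from_iterator() function takes a tf.data.Iterator object and returns a "saveable object" that a tf.train.Saver can write to a checkpoint alongside your variables. It saves the entire state of the iterator, including any shuffled data, so restoring the checkpoint resumes input from the saved position rather than from the start of the epoch.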

The following example code shows how to add a simple iterator to the same checkpoint used for variables:

```python
import tensorflow as tf  # TensorFlow 1.x; tf.contrib does not exist in 2.x

ds = tf.data.Dataset.range(10)
iterator = ds.make_initializable_iterator()

# [Build the training graph, using `iterator.get_next()` as the input.]

# Build the iterator SaveableObject.
saveable_obj = tf.contrib.data.make_saveable_from_iterator(iterator)

# Add the SaveableObject to the SAVEABLE_OBJECTS collection so
# it will be saved automatically using a Saver.
tf.add_to_collection(tf.GraphKeys.SAVEABLE_OBJECTS, saveable_obj)

# Create a saver. By default it saves all global variables as well as every
# object in the `tf.GraphKeys.SAVEABLE_OBJECTS` collection.
saver = tf.train.Saver()

with tf.Session() as sess:
  # An initializable iterator must be initialized before its first use.
  sess.run(iterator.initializer)

  # `continue_training` and `should_save_checkpoint` stand in for your own
  # training-loop logic.
  while continue_training:

    # [Perform training.]

    if should_save_checkpoint:
      saver.save(sess, ...)
```

Note that the iterator checkpointing support is currently (as of TensorFlow 1.8) in an experimental state, so the checkpoint format may change from one version to the next.
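For readers on TensorFlow 2.x, where `tf.contrib` has been removed, a rough equivalent is to put the iterator directly into a `tf.train.Checkpoint`, because iterators created with `iter(dataset)` are trackable objects. The sketch below assumes TensorFlow 2.x with eager execution; the dataset, the shuffle seed, and the temporary checkpoint directory are illustrative choices, not part of the original answer. It saves mid-epoch, reads a few more elements, then restores and reads again to show that the same elements come back.

```python
import tempfile

import tensorflow as tf  # assumes TensorFlow 2.x with eager execution

# A shuffled dataset; the seed is only here to make the run repeatable.
ds = tf.data.Dataset.range(10).shuffle(buffer_size=10, seed=42)
iterator = iter(ds)

# The iterator is trackable, so it can live in a Checkpoint next to
# model variables (none are needed for this demonstration).
ckpt = tf.train.Checkpoint(iterator=iterator)
ckpt_dir = tempfile.mkdtemp()  # illustrative location for the checkpoint files
manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3)

# Consume part of the epoch, then checkpoint the iterator state,
# which includes the contents of the shuffle buffer.
first = [int(next(iterator).numpy()) for _ in range(3)]
manager.save()
after_save = [int(next(iterator).numpy()) for _ in range(3)]

# Simulate resuming after a crash: restore the saved iterator state
# and read the same number of elements again.
ckpt.restore(manager.latest_checkpoint)
after_restore = [int(next(iterator).numpy()) for _ in range(3)]

print(first, after_save, after_restore)
```

Because the shuffle buffer is part of the saved state, checkpoints can become large when `buffer_size` is large.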

Answered Apr 11 '26 05:04 by mrry


