I am using a tf.train.string_input_producer with an epoch limit to feed data into my model. How can I get the current epoch of this op during training?
I noticed there are some nodes in the graph related to this op, one of which contains the epoch limit, but I can't find where the actual current value is being stored. Surely this is being tracked somewhere?
More generally, how can I monitor the current epoch in a TFRecords pipeline?
I was unable to find this anywhere in TF.
My solution was to track the epoch manually: batch on (infinite) repeat and run the training ops a fixed number of times per epoch, where that number is computed in advance as the number of items in the dataset divided by the batch size.
This became easier with the recent release of tensorflow.contrib.data.TFRecordDataset:
d = TFRecordDataset('some_filename.tfrecords')
d = d.map(function_which_parses_your_protobuf_format)
d = d.repeat()                    # repeat indefinitely; epochs are counted manually
d = d.shuffle(buffer_size=10000)  # shuffle() requires a buffer size
d = d.batch(batch_size)
You can then determine the size of your dataset using
record_count = sum(1 for _ in tf.python_io.tf_record_iterator('your_filename.tfrecord'))
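With the record count in hand, the bookkeeping reduces to integer arithmetic. A minimal sketch (the function name and arguments are illustrative, not TF API) of deriving the current epoch from the global training step:

```python
# Sketch: recover the current epoch from the global step, assuming
# record_count and batch_size are known. Names here are illustrative.
def current_epoch(global_step, record_count, batch_size):
    """Return the 0-based epoch that the given training step falls in."""
    steps_per_epoch = record_count // batch_size  # full batches per pass
    return global_step // steps_per_epoch

# Example: 1000 records with batch size 32 -> 31 full batches per epoch.
print(current_epoch(0, 1000, 32))   # step 0 lies in epoch 0
print(current_epoch(62, 1000, 32))  # step 62 lies in epoch 2
```

You can evaluate this alongside your training loop (e.g. with the session's global step) to log or act on epoch boundaries.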
It seems like more work, but it gives you more flexibility: for example, you can use caching, so you don't have to preprocess your dataset in advance and can store the original, untouched data in a TFRecord file.