Even after adding a .cache() step to my dataset pipeline, successive training epochs still download the data from the network storage.
I have a dataset on network storage. I want to cache it, but not repeat it: a training epoch must run through the whole dataset. Here is my dataset-building pipeline:
return tf.data.Dataset.list_files(
    file_pattern
).interleave(
    tf.data.TFRecordDataset,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).shuffle(
    buffer_size=2048
).batch(
    batch_size=2048,
    drop_remainder=True,
).cache(
).map(
    map_func=_parse_example_batch,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).prefetch(
    buffer_size=32
)
If I use it as is, the dataset is downloaded at each epoch. To avoid this, I have to add a .repeat() step to the pipeline and use the steps_per_epoch keyword argument of model.fit. However, I do not know the size of the complete dataset, so I cannot pass the right steps_per_epoch value.
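For illustration, that workaround looks roughly like this (a sketch; build_dataset is a hypothetical wrapper around the pipeline above, and steps_per_epoch is exactly the value I cannot compute):

dataset = build_dataset(file_pattern).repeat()  # cycles through the data indefinitely

model.fit(
    dataset,
    epochs=10,
    steps_per_epoch=steps_per_epoch,  # number of batches per epoch; unknown to me
)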
What is the right way to cache and use a dataset of unknown size?
Thanks.
While reading some TF code, I (re)discovered make_initializable_iterator. It seems to be what I am looking for, that is, a way to iterate multiple times through the same dataset (taking advantage of the cache after the first iteration). However, it is deprecated and no longer part of the main API in TF2.
The migration instructions say to iterate over the Dataset manually with for ... in dataset. Isn't that what keras.Model.fit already does? Do I have to write the training loop manually to benefit from the cache?
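By "manually", I mean something like the following sketch (hypothetical names; model, loss_fn and optimizer are assumed to exist already, and each dataset element is assumed to be a (features, labels) batch):

for epoch in range(num_epochs):
    # After the first pass, iterating again should read from the cache
    # instead of the network storage.
    for x_batch, y_batch in dataset:
        with tf.GradientTape() as tape:
            predictions = model(x_batch, training=True)
            loss = loss_fn(y_batch, predictions)
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))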
Kind regards.
In TF2.0, you do not need .repeat(). Regarding
"successive training epochs still download the data from the network storage."
I think you got confused with the "filling up shuffle buffer" message. This appears before every epoch if you are using the shuffle() function. Maybe try without shuffle(), just to see the difference.
Also, I would suggest you use cache() after map() and before batch().
EDIT
"filling up shuffle buffer"
is a message you get when using the shuffle() function. You can still shuffle() the dataset after using cache(). Look here.
Also, if I understood correctly, you are feeding the dataset resulting from map() to your model for training; you should cache() that dataset, not the earlier one, because training is done on it.
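As a sketch, the reordered pipeline could look like this (same stages as yours, only rearranged; I assume a per-example parsing function, which I call _parse_example here, since your _parse_example_batch operates on whole batches):

return tf.data.Dataset.list_files(
    file_pattern
).interleave(
    tf.data.TFRecordDataset,
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).map(
    map_func=_parse_example,  # hypothetical per-example parser
    num_parallel_calls=tf.data.experimental.AUTOTUNE
).cache(
).shuffle(
    buffer_size=2048
).batch(
    batch_size=2048,
    drop_remainder=True,
).prefetch(
    buffer_size=32
)

This way the parsed examples are cached once, and shuffle() still reorders them on every epoch.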
To count the number of elements in your dataset you can use the following code:
num_elements = 0
for element in dataset:  # dataset is a tf.data.Dataset
    num_elements += 1
print('Total number of elements in the file: ', num_elements)
Now, by dividing num_elements by your batch_size you get steps_per_epoch.
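Putting it together (a sketch; this assumes num_elements was counted on the unbatched dataset and batch_size is the same value passed to batch()):

steps_per_epoch = num_elements // batch_size  # integer division, consistent with drop_remainder=True

model.fit(
    dataset,  # the batched, cached pipeline
    epochs=10,
    steps_per_epoch=steps_per_epoch,
)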