
How to cache data during the first epoch correctly (Tensorflow, dataset)?

I'm trying to use the cache transformation for a dataset. Here is my current code (simplified):

dataset = tf.data.TFRecordDataset(filenames, num_parallel_reads=1)
dataset = dataset.apply(tf.contrib.data.shuffle_and_repeat(buffer_size=5000, count=1))
dataset = dataset.map(_parser_a, num_parallel_calls=12)
dataset = dataset.padded_batch(
    20, 
    padded_shapes=padded_shapes,
    padding_values=padding_values
)
dataset = dataset.prefetch(buffer_size=1)
dataset = dataset.cache()

After the first epoch, I received the following error message:

The calling iterator did not fully read the dataset we were attempting to cache. In order to avoid unexpected truncation of the sequence, the current [partially cached] sequence will be dropped. This can occur if you have a sequence similar to dataset.cache().take(k).repeat(). Instead, swap the order (i.e. dataset.take(k).cache().repeat())

Then, the code proceeded and still read data from the hard drive instead of the cache. So, where should I place dataset.cache() to avoid the error? Thanks.

Maosi Chen asked May 24 '18 23:05

People also ask

What does dataset cache do?

The Dataset.cache transformation can cache a dataset, either in memory or on local storage. This saves some operations (like file opening and data reading) from being executed during each epoch.
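For illustration, a minimal sketch of both caching modes (the file name, cache path, and parse function are placeholders, not taken from the question):

import tensorflow as tf

def parse_fn(record):
    # Placeholder parser; a real one would use tf.parse_single_example
    return record

dataset = tf.data.TFRecordDataset(["data.tfrecord"])  # placeholder file name
dataset = dataset.map(parse_fn)
dataset = dataset.cache()                              # cache parsed records in memory
# dataset = dataset.cache("/tmp/records_cache")        # or cache to a file on local storage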

How can TensorFlow be used to configure the dataset for performance?

TensorFlow can be used to configure the dataset for performance using the AUTOTUNE attribute in the tf.data module. Buffered prefetching is used so that data can be read from disk without I/O becoming blocking.
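A short sketch of that idea, assuming a dataset and parse_fn like the ones above (in TF 1.13+ the constant lives under tf.data.experimental; in TF 2.x it is also exposed as tf.data.AUTOTUNE):

AUTOTUNE = tf.data.experimental.AUTOTUNE

dataset = dataset.map(parse_fn, num_parallel_calls=AUTOTUNE)  # let the runtime pick the parallelism
dataset = dataset.prefetch(buffer_size=AUTOTUNE)              # overlap preprocessing with training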

How many batches are in TensorFlow dataset?

The number of batches depends on the batch size you choose; by default (for example, in Keras' model.fit), the batch size (batch_size) is 32.

What is PrefetchDataset?

Creates a dataset that asynchronously prefetches elements from input_dataset .
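This is the kind of dataset produced by the dataset.prefetch(buffer_size=1) line in the question: while the consumer works on the current batch, the next one is prepared in the background. A minimal sketch:

dataset = dataset.batch(20)
dataset = dataset.prefetch(buffer_size=1)  # keep one batch ready while the current one is consumed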


1 Answer

The implementation of the Dataset.cache() transformation is fairly simple: it builds up a list of the elements that pass through it as you iterate over it completely the first time, and it returns elements from that list on subsequent attempts to iterate over it. If the first pass only partially reads the data, the list is incomplete, and TensorFlow doesn't try to use the cached data, because it doesn't know whether the remaining elements will be needed, and in general it might need to reprocess all the preceding elements to compute the remaining elements.

If you modify your program to consume the entire dataset, iterating over it until tf.errors.OutOfRangeError is raised, the cache will contain a complete list of the elements in the dataset, and it will be used on all subsequent iterations.
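A minimal TF 1.x-style sketch of that fix, assuming the pipeline from the question ends in dataset.cache() and that the iterator is re-initialized for each epoch (num_epochs is a placeholder, not from the question):

iterator = dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for epoch in range(num_epochs):        # num_epochs is a placeholder
        sess.run(iterator.initializer)
        while True:
            try:
                sess.run(next_element)     # consume every element; the first complete pass fills the cache
            except tf.errors.OutOfRangeError:
                break                      # reached the end of the dataset for this epoch

Because the first pass runs to OutOfRangeError, the cache is fully populated, and later epochs are served from it instead of re-reading the TFRecord files.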

mrry answered Oct 06 '22 11:10