Is it possible to delete the in-memory cache that's built after calling tf.data.Dataset.cache()?
Here's what I'd like to do. The augmentation for the dataset is very costly, so the current code is more or less:
data = tf.data.Dataset(...) \
    .map(<expensive_augmentation>) \
    .cache()
    # .shuffle().batch() etc.
However, this means that every iteration over data will see the same augmented versions of the data samples. What I'd like to do instead is to use the cache for a couple of epochs and then start over, or equivalently do something like Dataset.map(<augmentation>).fleeting_cache().repeat(8). Is this possible to achieve?
The cache lifecycle is tied to the dataset, so you can achieve this by re-creating the dataset:
def create_dataset():
    dataset = tf.data.Dataset(...)
    dataset = dataset.map(<expensive_augmentation>)
    dataset = dataset.cache()  # cache the expensively augmented samples in memory
    dataset = dataset.shuffle(...)
    dataset = dataset.batch(...)
    return dataset
for epoch in range(num_epochs):
    # Rebuild the dataset (and thereby drop the cache) every 8 epochs.
    if epoch % 8 == 0:
        dataset = create_dataset()
    for batch in dataset:
        train(batch)
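For illustration, here is a minimal self-contained sketch of the same pattern. The toy tensor data, the random-noise "augmentation", the batch size, and the epoch count are all placeholders standing in for your real pipeline, not part of the original answer:

import tensorflow as tf

def create_dataset():
    # Toy stand-in for the real, expensive pipeline.
    dataset = tf.data.Dataset.from_tensor_slices(tf.zeros([100, 8, 8, 3]))
    dataset = dataset.map(
        lambda x: x + tf.random.uniform(tf.shape(x)),  # placeholder augmentation
        num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.cache()    # augmented samples are cached in memory
    dataset = dataset.shuffle(100)
    dataset = dataset.batch(16)
    return dataset

num_epochs = 24  # arbitrary
for epoch in range(num_epochs):
    # Rebuilding the dataset discards the old cache, so the random
    # augmentations are recomputed (and re-cached) every 8 epochs.
    if epoch % 8 == 0:
        dataset = create_dataset()
    for batch in dataset:
        pass  # train(batch)

Because cache() records whatever the preceding map() produced, the random augmentations stay frozen for as long as that particular dataset object lives; rebuilding the dataset is what forces them to be recomputed.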