tf.data.Dataset - delete cache?

Is it possible to delete the in-memory cache that's built after calling tf.data.Dataset.cache()?

Here's what I'd like to do. The augmentation for the dataset is very costly, so the current code is more or less:

data = (tf.data.Dataset.from_tensor_slices(...)  # or any other source
        .map(expensive_augmentation)
        .cache()
        # .shuffle().batch() etc.
        )

However, this means that every iteration over data will see the same augmented versions of the data samples. What I'd like to do instead is to use the cache for a couple of epochs and then start over, or equivalently do something like Dataset.map(<augmentation>).fleeting_cache().repeat(8). Is this possible to achieve?
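A minimal sketch that reproduces the behaviour described above, assuming TensorFlow 2.x in eager mode. `random_augment` is a hypothetical, cheap stand-in for the expensive augmentation (it just adds random noise), and `tf.data.Dataset.range(4)` replaces the real data source:

```python
import tensorflow as tf


def random_augment(x):
    # Hypothetical stand-in for the expensive augmentation: add random noise.
    return tf.cast(x, tf.float32) + tf.random.uniform([])


# Augment, then keep the augmented samples in an in-memory cache.
data = tf.data.Dataset.range(4).map(random_augment).cache()

# The first full pass runs random_augment and fills the cache; the second
# pass over the same dataset object is served from that cache.
epoch_1 = [float(v) for v in data]
epoch_2 = [float(v) for v in data]
print(epoch_1 == epoch_2)  # True: every epoch sees the same augmented samples
```

Without the `.cache()` call, the second pass would re-run `random_augment` and produce different values.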

asked by cloudy on Nov 15 '22 at 02:11


1 Answer

The cache's lifetime is tied to the dataset object it belongs to, so you can drop the cache by re-creating the dataset:

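import tensorflow as tf

def create_dataset():
  dataset = tf.data.Dataset.from_tensor_slices(...)  # or any other source
  dataset = dataset.map(expensive_augmentation)
  dataset = dataset.cache()  # each new dataset object gets a fresh, empty cache
  dataset = dataset.shuffle(...)
  dataset = dataset.batch(...)
  return dataset

for epoch in range(num_epochs):
  # Drop the cache every 8 epochs by building a new dataset (and cache).
  if epoch % 8 == 0:
    dataset = create_dataset()
  for batch in dataset:
    train(batch)  # your training step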
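A small, self-contained sketch of the same idea, assuming TensorFlow 2.x in eager mode. The random-noise `map` is a hypothetical stand-in for the expensive augmentation, and `datasets_per_epoch` / `epochs_per_cache` are illustrative names, not part of the tf.data API. The generator hands out the same cached dataset for `epochs_per_cache` epochs and then rebuilds it, which discards the old cache:

```python
import tensorflow as tf


def create_dataset():
    # Hypothetical pipeline: random noise stands in for the expensive augmentation.
    dataset = tf.data.Dataset.range(4)
    dataset = dataset.map(lambda x: tf.cast(x, tf.float32) + tf.random.uniform([]))
    return dataset.cache()


def datasets_per_epoch(num_epochs, epochs_per_cache):
    """Yield the dataset to use for each epoch, rebuilding it (and therefore
    its in-memory cache) every `epochs_per_cache` epochs."""
    dataset = None
    for epoch in range(num_epochs):
        if epoch % epochs_per_cache == 0:
            dataset = create_dataset()
        yield dataset


# Record the values seen in each of 4 epochs, refreshing the cache every 2.
seen = [[float(v) for v in ds] for ds in datasets_per_epoch(4, 2)]
print(seen[0] == seen[1], seen[2] == seen[3])  # True True: same cache within a block
print(seen[0] == seen[2])  # a rebuilt dataset re-runs the augmentation
```

Each epoch must read the cached dataset to the end, since tf.data only keeps an in-memory cache once the first pass over it has completed.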
answered by AAudibert on Dec 11 '22 at 04:12