
How to speed up batch preparation when using Estimators API combined with tf.data.Dataset

I'd like to speed up my training routine, which uses the Estimator API with an input_fn written using tf.data.Dataset.

My implementation takes 2 seconds to prepare a batch of data, then runs training on the GPU for 1 second, and then starts over preparing the next batch, which is really inefficient.

I'm looking for a way to prepare the batches asynchronously and upload them to the GPU to speed up training. Alternatively, I'm looking for a way to cache datasets between invocations of input_fn (dataset.cache() doesn't seem to be a good choice, as the dataset has to be recreated on each input_fn invocation).

Here is a simplified version of my code:

import tensorflow as tf

def input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
  if shuffle:
    dataset = dataset.shuffle(buffer_size=len(labels))
  dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
  dataset = dataset.map(lambda wav, label: ({'wav': wav}, label))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs)  # epochs=None iterates over the training set forever
  iterator = dataset.make_one_shot_iterator()
  features, labels = iterator.get_next()
  return features, labels

train_input_fn = lambda: input_fn(train_files, train_labels, None)
eval_input_fn = lambda: input_fn(eval_files, eval_labels, 1)

train_spec = tf.estimator.TrainSpec(input_fn=train_input_fn, max_steps=45000)
eval_spec = tf.estimator.EvalSpec(input_fn=eval_input_fn) 
tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)

I've noticed that the Estimator API is under active development and in the master branch of tensorflow the input_fn can return datasets already, so maybe I'm asking too early and this feature isn't ready yet. But if so, please provide a ticket where this implementation can be tracked.
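For reference, a minimal sketch of what that would look like, assuming a TensorFlow version where input_fn may return the tf.data.Dataset itself and the Estimator creates the iterator internally (dataset_input_fn is a hypothetical name):

import tensorflow as tf

def dataset_input_fn(filenames, labels, epochs):
  dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
  dataset = dataset.batch(128)
  dataset = dataset.repeat(epochs)
  # Return the dataset directly; no make_one_shot_iterator()/get_next() needed.
  return dataset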

asked Dec 13 '22 by Piotr Czapla


1 Answer

Using tf.data.Dataset.cache() is indeed not a good choice, since it caches the whole dataset into memory, which takes time and can exhaust your memory.

The way to go is to use tf.data.Dataset.prefetch() at the end of your pipeline, which makes sure the pipeline always holds up to buffer_size elements ready for consumption. It is usually enough to have buffer_size = 1 at the end:

dataset = ...
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)  # prefetch one batch

As explained by @mrry in this answer, you can also try to increase the number of prefetched batches a bit.

Typically it is most useful to add a small prefetch buffer (with perhaps just a single element) at the very end of the pipeline, but more complex pipelines can benefit from additional prefetching, especially when the time to produce a single element can vary.
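For instance, applied to a pipeline like the one in the question, an extra buffer after the expensive decode step can smooth out variance in how long each element takes to produce. A sketch, reusing _read_wav, _post_process and num_map_threads from the question's code; the buffer sizes are illustrative starting points, not tuned values:

dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)
dataset = dataset.prefetch(num_map_threads)  # buffer decoded elements ahead of post-processing
dataset = dataset.map(_post_process, num_parallel_calls=num_map_threads)
dataset = dataset.batch(128)
dataset = dataset.prefetch(1)                # keep one batch ready for the training step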


If your input pipeline is still slow compared to your GPU computations, increase the number of threads working in parallel via the num_parallel_calls argument of tf.data.Dataset.map().
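For example, a common starting point is to size the thread pool to the number of CPU cores (a sketch; the best value is workload-dependent and worth tuning):

import multiprocessing

num_map_threads = multiprocessing.cpu_count()  # one map call per core as a baseline
dataset = dataset.map(_read_wav, num_parallel_calls=num_map_threads)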

answered May 16 '23 by Olivier Moindrot