Parallel threads with TensorFlow Dataset API and flat_map

I'm converting my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument to the tf.train.shuffle_batch queue. In the Dataset API, however, the only way to control the number of threads seems to be the num_parallel_calls argument of the map function, and I'm using flat_map instead, which doesn't have such an argument.

Question: Is there a way to control the number of threads/processes for the flat_map function? Or is there a way to use map in combination with flat_map and still specify the number of parallel calls?

Note that it is of crucial importance to run multiple threads in parallel, as I intend to run heavy pre-processing on the CPU before data enters the queue.

There are two related posts on GitHub (here and here), but I don't think they answer this question.

Here is a minimal code example of my use-case for illustration:

import tensorflow as tf

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'
asked Nov 21 '17 by CNugteren

People also ask

How do I iterate over a TensorFlow dataset?

To iterate over the dataset several times, use .repeat(). Each batch can then be enumerated with either Python's enumerate() or the dataset's built-in enumerate() method.
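For illustration, a minimal sketch of both styles of iteration (assuming TensorFlow 2.x with eager execution; the toy dataset is made up for the example):

import tensorflow as tf

# A toy dataset: six integers, batched in pairs and repeated twice.
dataset = tf.data.Dataset.range(6).batch(2).repeat(2)

# Python's enumerate gives a running batch index.
for i, batch in enumerate(dataset):
    print(i, batch.numpy())

# Dataset.enumerate() attaches the index to each element instead.
for i, batch in dataset.enumerate():
    print(int(i), batch.numpy())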

What is interleave TensorFlow?

TL;DR: interleave() parallelizes the data-loading step by interleaving the I/O operations that read the files. map() applies the data pre-processing to the contents of the datasets.
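As a hedged sketch of that split (the shard filenames, the parse_fn placeholder and the use of tf.data.AUTOTUNE are assumptions for a recent TF 2.x release, not something from the question):

import tensorflow as tf

# Hypothetical shard filenames; replace with real paths.
filenames = tf.data.Dataset.from_tensor_slices(
    ["shard-0.tfrecord", "shard-1.tfrecord", "shard-2.tfrecord"])

# interleave() opens `cycle_length` files at once and mixes their records,
# so the read I/O of several files overlaps instead of running sequentially.
records = filenames.interleave(
    lambda path: tf.data.TFRecordDataset(path),
    cycle_length=3,
    num_parallel_calls=tf.data.AUTOTUNE)

# map() then applies per-record pre-processing, e.g.
# parsed = records.map(parse_fn, num_parallel_calls=tf.data.AUTOTUNE)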

What does TF data dataset From_tensor_slices do?

tf.data.Dataset.from_tensor_slices() slices the given array or tensors along their first dimension and returns a Dataset whose elements are those slices.
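A short example of that slicing behaviour (the array is made up; shapes noted in comments):

import numpy as np
import tensorflow as tf

# A (4, 512) array becomes a Dataset of four elements, each of shape (512,).
features = np.ones((4, 512), dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(features)

for element in dataset:
    print(element.shape)  # (512,)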


1 Answer

To the best of my knowledge, flat_map does not currently offer any parallelism option. Given that the bulk of the computation is done in pre_processing_func, a workaround is a parallel map call followed by some buffering, and then a flat_map call with an identity-like lambda function that takes care of flattening the output.

In code:

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here:
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return artificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors)
                  .map(pre_processing_func, num_parallel_calls=NUM_THREADS)
                  .prefetch(BUFFER_SIZE)
                  .flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
                  .shuffle(BUFFER_SIZE))  # my addition, probably necessary though

Note (to myself and whoever will try to understand the pipeline):

Since pre_processing_func generates an arbitrary number of new samples from each initial sample (organised in matrices of shape (?, 512)), the flat_map call is needed to turn every generated matrix into a Dataset of single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then to flatten all of these datasets into one big Dataset of individual samples.

It's probably a good idea to .shuffle() that dataset, otherwise the samples generated from the same input will end up packed together.
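To make the shape bookkeeping concrete, here is a minimal sketch of the same pipeline, assuming TF 2.x eager execution and using a trivial stand-in for generate_from_sample that just stacks three scaled copies of each sample (the real augmentation would go there):

import tensorflow as tf

NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # Stand-in augmentation: build a (3, 512) matrix of "artificial"
    # samples out of one (512,) input sample.
    return tf.stack([data_, data_ * 2.0, data_ * 3.0])

data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")

dataset = (tf.data.Dataset.from_tensor_slices(data)                    # 10 samples of shape (512,)
           .map(pre_processing_func, num_parallel_calls=NUM_THREADS)   # one (3, 512) matrix each
           .prefetch(BUFFER_SIZE)
           .flat_map(tf.data.Dataset.from_tensor_slices)               # back to single (512,) samples
           .shuffle(BUFFER_SIZE))

for sample in dataset.take(2):
    print(sample.shape)  # (512,)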

answered Sep 21 '22 by GPhilo