I'm migrating my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument of the tf.train.shuffle_batch queue. In the Dataset API, the only way to control the number of threads seems to be the num_parallel_calls argument of the map function. However, I'm using flat_map instead, which has no such argument.
Question: Is there a way to control the number of threads/processes for the flat_map function? Or is there a way to use map in combination with flat_map and still specify the number of parallel calls?
Note that it is of crucial importance to run multiple threads in parallel, as I intend to run heavy pre-processing on the CPU before data enters the queue.
There are two (here and here) related posts on GitHub, but I don't think they answer this question.
Here is a minimal code example of my use-case for illustration:
with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)
    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)
    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)
    # do something with 'dataset'
To the best of my knowledge, at the moment flat_map does not offer parallelism options.
Given that the bulk of the computation is done in pre_processing_func, what you might use as a workaround is a parallel map call followed by some buffering, and then using a flat_map call with an identity lambda function that takes care of flattening the output.
In code:
NUM_THREADS = 5
BUFFER_SIZE = 1000
def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return artificial_samples
dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors).
                  map(pre_processing_func, num_parallel_calls=NUM_THREADS).
                  prefetch(BUFFER_SIZE).
                  flat_map(lambda *x : tf.data.Dataset.from_tensor_slices(x)).
                  shuffle(BUFFER_SIZE)) # my addition, probably necessary though
Since pre_processing_func generates an arbitrary number of new samples starting from the initial sample (organised in matrices of shape (?, 512)), the flat_map call is necessary to turn all the generated matrices into Datasets containing single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then flatten all these datasets into one big Dataset containing individual samples.
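For readers who want to run the workaround end to end, here is a minimal sketch. It assumes TensorFlow 2.x with eager execution (the question uses graph mode), and `augment` is a hypothetical stand-in for the real pre-processing: it simply turns each sample of shape (512,) into a matrix of shape (3, 512). The parallel-call count and buffer sizes are arbitrary illustrative values.

```python
import tensorflow as tf

NUM_PARALLEL_CALLS = 4  # illustrative value, tune for your CPU


def augment(sample):
    # Hypothetical augmentation: build 3 variants of one (512,) sample,
    # stacked into a single (3, 512) matrix.
    return tf.stack([sample, sample * 2.0, sample * 3.0], axis=0)


data = tf.ones(shape=(10, 512), dtype=tf.float32)

dataset = (
    tf.data.Dataset.from_tensor_slices(data)
    # The heavy work runs here, in parallel.
    .map(augment, num_parallel_calls=NUM_PARALLEL_CALLS)
    # Buffer a few augmented matrices ahead of the consumer.
    .prefetch(16)
    # Cheap step: split every (3, 512) matrix into 3 elements of shape (512,).
    .flat_map(tf.data.Dataset.from_tensor_slices)
    .shuffle(buffer_size=30)
)

# 10 input samples, each expanded into 3 samples.
num_samples = sum(1 for _ in dataset)
```

Because `augment` returns a single tensor here, `from_tensor_slices` can be passed to `flat_map` directly. If the function returned a tuple of tensors, the `lambda *x` form from the answer would be needed instead.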
It's probably a good idea to .shuffle() that dataset, or generated samples will be packed together.
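To see the grouping effect, the sketch below tags every generated sample with the id of the sample it came from. It again assumes TensorFlow 2.x with eager execution, and `make_copies` is a hypothetical placeholder for the augmentation step.

```python
import tensorflow as tf


def make_copies(source_id):
    # Hypothetical augmentation: emit 3 samples that all carry their source id.
    return tf.fill([3], source_id)


flat = (
    tf.data.Dataset.range(4)
    .map(make_copies, num_parallel_calls=2)
    .flat_map(tf.data.Dataset.from_tensor_slices)
)

# Without shuffling, samples from the same source stay adjacent:
# [0, 0, 0, 1, 1, 1, 2, 2, 2, 3, 3, 3]
unshuffled = [int(x.numpy()) for x in flat]

# With a buffer covering the whole dataset, the same samples come out mixed.
shuffled = [int(x.numpy()) for x in flat.shuffle(buffer_size=12, seed=0)]
```

A shuffle buffer smaller than the number of samples generated per input would still leave neighbouring samples strongly correlated, so the buffer size should be chosen with that in mind.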