I'm changing my TensorFlow code from the old queue interface to the new Dataset API. With the old interface I could specify the num_threads argument to the tf.train.shuffle_batch queue. However, the only way to control the number of threads in the Dataset API seems to be the num_parallel_calls argument of the map function, and I'm using the flat_map function instead, which doesn't have such an argument.
Question: Is there a way to control the number of threads/processes for the flat_map function? Or is there a way to use map in combination with flat_map and still specify the number of parallel calls?
Note that it is of crucial importance to run multiple threads in parallel, as I intend to run heavy pre-processing on the CPU before data enters the queue.
There are two (here and here) related posts on GitHub, but I don't think they answer this question.
Here is a minimal code example of my use-case for illustration:
import tensorflow as tf

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")
    input_tensors = (data,)

    def pre_processing_func(data_):
        # normally I would do data-augmentation here
        results = (tf.expand_dims(data_, axis=0),)
        return tf.data.Dataset.from_tensor_slices(results)

    dataset_source = tf.data.Dataset.from_tensor_slices(input_tensors)
    dataset = dataset_source.flat_map(pre_processing_func)

    # do something with 'dataset'
To the best of my knowledge, at the moment flat_map does not offer parallelism options. Given that the bulk of the computation is done in pre_processing_func, what you might use as a workaround is a parallel map call followed by some buffering, and then a flat_map call with an identity lambda function that takes care of flattening the output.
In code:
NUM_THREADS = 5
BUFFER_SIZE = 1000

def pre_processing_func(data_):
    # data-augmentation here
    # generate new samples starting from the sample `data_`
    artificial_samples = generate_from_sample(data_)
    return artificial_samples

dataset_source = (tf.data.Dataset.from_tensor_slices(input_tensors)
                  .map(pre_processing_func, num_parallel_calls=NUM_THREADS)
                  .prefetch(BUFFER_SIZE)
                  .flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
                  .shuffle(BUFFER_SIZE))  # my addition, probably necessary though
Since pre_processing_func generates an arbitrary number of new samples starting from the initial sample (organised in matrices of shape (?, 512)), the flat_map call is necessary to turn all the generated matrices into Datasets containing single samples (hence the tf.data.Dataset.from_tensor_slices(x) in the lambda) and then flatten all these datasets into one big Dataset containing individual samples.
It's probably a good idea to .shuffle() that dataset as well, otherwise the samples generated from the same input will end up packed together.
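For completeness, here is a self-contained sketch of the whole pipeline end to end. It assumes TF 1.x graph-mode execution (as in the question) and fills in a purely hypothetical generate_from_sample that just stacks four noisy copies of each sample; the final .batch(8) and the iterator/session code are only there to show how the resulting dataset might be consumed:
import tensorflow as tf

NUM_THREADS = 5
BUFFER_SIZE = 1000

def generate_from_sample(data_):
    # hypothetical augmentation: stack 4 noisy copies of the (512,) sample
    return tf.stack([data_ + tf.random_normal(tf.shape(data_)) for _ in range(4)])

with tf.Graph().as_default():
    data = tf.ones(shape=(10, 512), dtype=tf.float32, name="data")

    dataset = (tf.data.Dataset.from_tensor_slices((data,))
               .map(generate_from_sample, num_parallel_calls=NUM_THREADS)
               .prefetch(BUFFER_SIZE)
               .flat_map(lambda *x: tf.data.Dataset.from_tensor_slices(x))
               .shuffle(BUFFER_SIZE)
               .batch(8))  # arbitrary batch size, just for illustration

    next_batch = dataset.make_one_shot_iterator().get_next()
    with tf.Session() as sess:
        batch = sess.run(next_batch)
        print(batch[0].shape)  # (8, 512): 10 inputs * 4 generated samples each, batched by 8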