interleave is a tf.data.Dataset method that can be used to interleave together elements from multiple datasets. tf.contrib.data.parallel_interleave provides a parallel version of the same functionality via apply.
I can see that reading from many datasets in parallel, with buffers for each of them, should improve throughput, as the parallel version allows. But the documentation also has this to say about how parallel_interleave can increase data throughput:
Unlike tf.data.Dataset.interleave, it gets elements from cycle_length nested datasets in parallel, which increases the throughput, especially in the presence of stragglers.
What exactly are stragglers, and why does parallel_interleave work especially well in terms of throughput in their presence?
A straggler is a data source or function that takes longer than usual to produce its output. This can be due to network congestion, or to an unlucky combination of random factors.
interleave does all the processing sequentially, on a single thread. In the following schemas, let ___ denote waiting for I/O or computation, <waiting> denote waiting for its turn to emit an element, and 111 denote producing the first element (1).
Suppose we have a dataset of directories ds = [A, B, C, D] and we produce files 1, 2, 3, ... from each of them. Then r = ds.interleave(cycle_length=3, block_length=2) will work roughly like this:
A: ___111___222
B: <waiting> ___111___________222
C: <waiting> <waiting> <waiting> ___111___222
R: ____A1____A2____B1____________B2____C1____C2
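Ignoring timing, the ordering in the R line can be reproduced with a small pure-Python model of interleave's round-robin scheme. This is a sketch, not tf.data itself: interleave_order and the list-of-lists sources are hypothetical names used for illustration.

```python
from itertools import islice

def interleave_order(sources, cycle_length, block_length):
    """Pure-Python sketch of interleave's ordering (no parallelism, no
    timing): keep cycle_length sources open, take block_length elements
    from each in turn, and pull in the next source when one runs out."""
    pending = [iter(s) for s in sources]
    cycle = [pending.pop(0) for _ in range(min(cycle_length, len(pending)))]
    i = 0
    while cycle:
        block = list(islice(cycle[i], block_length))
        if block:
            yield from block
            i = (i + 1) % len(cycle)
        elif pending:        # this slot's source is exhausted: refill it
            cycle[i] = pending.pop(0)
        else:                # nothing left to refill with: shrink the cycle
            cycle.pop(i)
            if cycle:
                i %= len(cycle)

sources = [[f"{d}{n}" for n in (1, 2)] for d in "ABCD"]  # A1, A2, B1, ...
print(list(interleave_order(sources, cycle_length=3, block_length=2)))
# ['A1', 'A2', 'B1', 'B2', 'C1', 'C2', 'D1', 'D2']
```

With sources of two elements and block_length=2, each source contributes exactly one block, matching the A1 A2 B1 B2 C1 C2 order of the schema, with D joining the cycle once A is exhausted.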
You can see that if producing elements from B straggles, all subsequent elements have to wait for it before they can be processed.
parallel_interleave helps with stragglers in two ways. First, it starts producing elements from each dataset in the cycle in parallel (hence the name). Therefore, the production schema becomes:
A: ___111___222
B: ___<waiting>111___________222
C: ___<waiting><waiting><waiting>111___222
R: ____A1____A2_B1____________B2_C1____C2|....|
Waiting in parallel instead of sequentially reduces useless waiting. The |....| part shows how much we saved compared to the sequential version.
The second way it helps is through the sloppy argument. If we set it to True, the transformation can skip over an element that is not yet available and return the next one that is, at the cost of producing a non-deterministic order. Here's how:
A: ___111___<w>222
B: ___<w>111___________222
C: ___<w><w>111___222
R: ____A1_B1_C1_A2_C2___B2|...................|
Look at that saving! But also look at the order of the elements!
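To make the effect of sloppy concrete, here's a rough pure-Python timing model (not tf.data). Each source produces elements in parallel with made-up ready times, B being the straggler; the deterministic consumer takes elements in strict round-robin order, blocking until each is ready, while the sloppy one takes whichever pending element becomes ready first. drain and READY are hypothetical names, and real sloppy interleave still prefers cycle order among ready elements, so the exact order can differ from this sketch.

```python
from collections import deque

# Per-source ready times (seconds since all sources started producing in
# parallel). B is the straggler, as in the schemas above; the numbers are
# made up for illustration.
READY = {
    "A": [1, 2],   # A1 ready at t=1, A2 at t=2
    "B": [6, 7],   # B straggles
    "C": [2, 3],
}

def drain(ready, sloppy):
    """Rough model of consumption with block_length=1: returns a list of
    (element, time-it-was-delivered) pairs."""
    heads = {src: 0 for src in ready}     # index of the next element per source
    order = deque(sorted(ready))          # round-robin cycle: A, B, C
    t, out = 0, []
    while heads:
        if sloppy:
            # take whichever pending element is ready first
            src = min(heads, key=lambda s: ready[s][heads[s]])
        else:
            # strict round robin, blocking on the slot whose turn it is
            src = order[0]
            order.rotate(-1)
        t = max(t, ready[src][heads[src]])
        out.append((f"{src}{heads[src] + 1}", t))
        heads[src] += 1
        if heads[src] == len(ready[src]):
            del heads[src]
            if not sloppy:
                order.remove(src)
    return out

print(drain(READY, sloppy=False))
# [('A1', 1), ('B1', 6), ('C1', 6), ('A2', 6), ('B2', 7), ('C2', 7)]
print(drain(READY, sloppy=True))
# [('A1', 1), ('A2', 2), ('C1', 2), ('C2', 3), ('B1', 6), ('B2', 7)]
```

In the deterministic run everything after B1 is stuck behind the straggler; in the sloppy run four elements are already delivered by t=3, and only B's own elements arrive late.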
I reproduced this in code. It's not pretty, but it illustrates the differences a bit:
from time import sleep
import tensorflow as tf

DS = tf.data.Dataset

def repeater(val):
    # val is ignored: every inner dataset is the same slow generator,
    # which sleeps 1 second before every odd element
    def _slow_gen():
        for i in range(5):
            if i % 2:
                sleep(1)
            yield i
    return DS.from_generator(_slow_gen, tf.int8)

ds = DS.range(5)
slow_ds = ds.interleave(repeater, cycle_length=2, block_length=3)
para_ds = ds.apply(tf.contrib.data.parallel_interleave(
    repeater, cycle_length=2, block_length=3))
sloppy_ds = ds.apply(tf.contrib.data.parallel_interleave(
    repeater, cycle_length=2, block_length=3, sloppy=True))

%time apply_python_func(slow_ds, print, sess)
# 10 sec, you see it waiting each time
%time apply_python_func(para_ds, print, sess)
# 3 sec always! you see it burping a lot after the first wait
%time apply_python_func(sloppy_ds, print, sess)
# sometimes 3, sometimes 4 seconds
And here's the function used to display a dataset:
def apply_python_func(ds, func, sess):
    """Extract values from ds using sess and apply func to them"""
    it = ds.make_one_shot_iterator()
    next_value = it.get_next()
    num_examples = 0
    while True:
        try:
            value = sess.run(next_value)
            num_examples += 1
            func(value)
        except tf.errors.OutOfRangeError:
            break
    print('Evaluated {} examples'.format(num_examples))