I have multiple TensorFlow datasets that share the same structure, and I want to combine them into a single dataset using tf.data.Dataset.concatenate.
However, I found that when I shuffle the combined dataset, it is not shuffled across the whole dataset; elements are only shuffled within each of the original datasets.
Is there any way to solve this?
When you concatenate two Datasets, you get the elements of the first followed by the elements of the second. If you shuffle the result, you will not get a good mix if your shuffle buffer is smaller than the size of your Dataset.
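To see why the buffer size matters, here is a minimal pure-Python sketch of a fixed-size shuffle buffer. It is a simplified model and not TensorFlow code: the function name buffered_shuffle and the toy data are made up for illustration. The buffer is filled first, then each output is a random element of the buffer, which is replaced by the next input element.

```python
import random

def buffered_shuffle(items, buffer_size, seed=0):
    """Simplified model of a fixed-size shuffle buffer (illustrative only)."""
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)          # still filling the buffer
            continue
        idx = rng.randrange(buffer_size)
        yield buffer[idx]                # emit a random buffered element
        buffer[idx] = item               # and replace it with the new one
    rng.shuffle(buffer)                  # input exhausted: drain the buffer
    yield from buffer

# Two toy "datasets" of 100 elements each, concatenated.
first = [("a", i) for i in range(100)]
second = [("b", i) for i in range(100)]
concatenated = first + second

small = list(buffered_shuffle(concatenated, buffer_size=10))
large = list(buffered_shuffle(concatenated, buffer_size=len(concatenated)))

def count_from_second(elements):
    """Number of elements that came from the second dataset."""
    return sum(1 for name, _ in elements if name == "b")

print(count_from_second(small[:100]))  # at most 9: almost no mixing
print(count_from_second(large[:100]))  # roughly half: a proper mix
```

With a buffer of 10, an element of the second dataset can only be emitted after it has been read, so the first 100 outputs contain at most 9 of them. Only a buffer as large as the whole concatenated dataset gives a uniform shuffle.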
What you need instead is to interleave samples from your datasets. The best way, if you are using TF >= 1.9, is to use the dedicated tf.contrib.data.choose_from_datasets function. An example straight from the docs:
datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
            tf.data.Dataset.from_tensors("bar").repeat(),
            tf.data.Dataset.from_tensors("baz").repeat()]
# Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
choice_dataset = tf.data.Dataset.range(3).repeat(3)
result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)
If preserving the interleaving order and/or the ratio of samples from each dataset within a batch is important, it is probably better to shuffle the input datasets rather than the interleaved result.
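The selection logic can be illustrated without TensorFlow. The following is a pure-Python sketch of the idea only, not the library's implementation: choose_from_sources and the toy lists are hypothetical names. For each index in the choice sequence, it takes the next element from the corresponding source.

```python
def choose_from_sources(sources, choices):
    """For each index i in `choices`, yield the next element of sources[i]."""
    iterators = [iter(source) for source in sources]
    for i in choices:
        try:
            yield next(iterators[i])
        except StopIteration:
            return  # stop as soon as a selected source is exhausted

sources = [["a0", "a1", "a2"],
           ["b0", "b1", "b2"],
           ["c0", "c1", "c2"]]
choices = [0, 1, 2] * 3  # round-robin over the three sources

result = list(choose_from_sources(sources, choices))
print(result)
# ['a0', 'b0', 'c0', 'a1', 'b1', 'c1', 'a2', 'b2', 'c2']
```

Each source keeps its own internal order, and the choice sequence alone decides how the sources are interleaved.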
If you are using an earlier version of TF, you could rely on a combination of zip, flat_map and concatenate like this:
import tensorflow as tf  # TF 1.x API

a = tf.data.Dataset.range(3).repeat()
b = tf.data.Dataset.range(100, 105).repeat()
# Pair one element of a with one element of b, then flatten each pair
# into two consecutive elements so that a and b alternate.
value = (tf.data.Dataset
         .zip((a, b))
         .flat_map(lambda x, y: tf.data.Dataset.concatenate(
             tf.data.Dataset.from_tensors([x]),
             tf.data.Dataset.from_tensors([y])))
         .make_one_shot_iterator()
         .get_next())
sess = tf.InteractiveSession()
for _ in range(10):
    print(value.eval())
Starting from TensorFlow 1.9, you can also use the sample_from_datasets function, which draws elements at random from the input datasets.
For example, the following code
import tensorflow as tf  # TF 1.x API

# Each dataset yields (index, value) pairs with index 0, 1, 2, repeated forever.
datasets = [
    tf.data.Dataset.from_tensors("foo").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
    tf.data.Dataset.from_tensors("bar").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
    tf.data.Dataset.from_tensors("baz").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
]
dataset = tf.data.experimental.sample_from_datasets(datasets)  # from 1.12
# dataset = tf.contrib.data.sample_from_datasets(datasets)  # between 1.9 and 1.12
iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()
with tf.Session() as sess:
    for i in range(10):
        print(sess.run(next_element))
will print something like the following (the exact order varies between runs, since the sampling is random):
(0, b'bar')
(0, b'foo')
(1, b'bar')
(0, b'baz')
(2, b'bar')
(1, b'foo')
(1, b'baz')
(2, b'foo')
(2, b'baz')
(0, b'foo')
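In the output above, each dataset's own counter still runs 0, 1, 2 in order; only the choice of which dataset to read from is random. A pure-Python sketch of that behaviour follows. It illustrates the idea only and is not TensorFlow's implementation: sample_from_sources, its weights and seed parameters, and the toy data are assumptions, and the weights are assumed to be positive.

```python
import random

def sample_from_sources(sources, weights=None, seed=None):
    """Repeatedly pick a source at random and yield its next element."""
    rng = random.Random(seed)
    iterators = [iter(source) for source in sources]
    # Default to equal weights, i.e. uniform sampling across sources.
    weights = list(weights) if weights is not None else [1.0] * len(iterators)
    while iterators:
        idx = rng.choices(range(len(iterators)), weights=weights)[0]
        try:
            yield next(iterators[idx])
        except StopIteration:
            # This source is exhausted: keep sampling from the remaining ones.
            del iterators[idx]
            del weights[idx]

sources = [[(i, name) for i in range(3)] for name in ("foo", "bar", "baz")]
mixed = list(sample_from_sources(sources, seed=0))

for element in mixed:
    print(element)
```

All nine elements appear exactly once, the three sources are mixed at random, and the indices within each source still come out in the order 0, 1, 2.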