
How to shuffle a concatenated TensorFlow dataset

I have multiple TensorFlow datasets that have the same structure. I want to combine them into a single dataset using tf.data.Dataset.concatenate.

However, when I shuffle this combined dataset, the elements are not shuffled across the whole combined dataset; they are only shuffled within each of the separate datasets.

Is there any way to solve this?

Justin Chan asked Aug 09 '18 10:08


2 Answers

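When you concatenate two Datasets, you get all the elements of the first followed by all the elements of the second. Dataset.shuffle only samples from a buffer of buffer_size elements, so if that buffer is smaller than your combined Dataset, elements of the second Dataset cannot be emitted until enough of the first has been consumed to bring them into the buffer, and you will not get a good mix.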
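To see the effect, here is a plain-Python model of a fixed-size shuffle buffer, which is the strategy Dataset.shuffle uses: fill a buffer with buffer_size elements, emit one of them at random, and refill the freed slot with the next input element. buffered_shuffle is a hypothetical helper written for illustration (not TensorFlow's implementation), and the two 100-element sources are made up.

```python
import random


def buffered_shuffle(elements, buffer_size, seed=None):
    """Model of a fixed-size shuffle buffer (hypothetical helper, not TF code)."""
    rng = random.Random(seed)
    it = iter(elements)
    buffer = []
    # Fill the buffer with the first `buffer_size` elements.
    for x in it:
        buffer.append(x)
        if len(buffer) == buffer_size:
            break
    out = []
    # Emit a random buffered element and put the next input in its slot.
    for x in it:
        i = rng.randrange(len(buffer))
        out.append(buffer[i])
        buffer[i] = x
    # Input exhausted: drain the remaining buffer in random order.
    rng.shuffle(buffer)
    out.extend(buffer)
    return out


first = [("a", i) for i in range(100)]
second = [("b", i) for i in range(100)]
concatenated = first + second  # like first_ds.concatenate(second_ds)

small = buffered_shuffle(concatenated, buffer_size=10, seed=0)
full = buffered_shuffle(concatenated, buffer_size=len(concatenated), seed=0)

# With a 10-element buffer, the first 90 outputs can only come from `first`.
print(all(src == "a" for src, _ in small[:90]))  # True
```

With the small buffer, the output stays confined to the first source for at least the first 90 elements; only a buffer as large as the combined data (which then has to fit in memory) mixes the two sources fully, which is why interleaving the sources is usually the better option.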

What you need instead is to interleave samples from your datasets. If you are using TF >= 1.9, the best way is the dedicated tf.contrib.data.choose_from_datasets function (later available as tf.data.experimental.choose_from_datasets). An example straight from the docs:

import tensorflow as tf

datasets = [tf.data.Dataset.from_tensors("foo").repeat(),
            tf.data.Dataset.from_tensors("bar").repeat(),
            tf.data.Dataset.from_tensors("baz").repeat()]

# Define a dataset containing `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.
choice_dataset = tf.data.Dataset.range(3).repeat(3)

# Yields "foo", "bar", "baz", "foo", "bar", "baz", "foo", "bar", "baz".
result = tf.contrib.data.choose_from_datasets(datasets, choice_dataset)
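To make the semantics concrete without TensorFlow, here is a plain-Python model of what choose_from_datasets does with the example above: for each index in the choice sequence, it takes the next element from the source at that index. choose_from is a hypothetical helper, not a TensorFlow API, and stopping as soon as a chosen source runs out is a simplifying assumption of this sketch.

```python
import itertools


def choose_from(sources, choices):
    """For each index in `choices`, yield the next element of sources[index]."""
    iterators = [iter(s) for s in sources]
    for index in choices:
        try:
            value = next(iterators[index])
        except StopIteration:
            return  # sketch assumption: stop when the chosen source is empty
        yield value


sources = [itertools.repeat("foo"), itertools.repeat("bar"), itertools.repeat("baz")]
choices = [0, 1, 2] * 3  # like tf.data.Dataset.range(3).repeat(3)
print(list(choose_from(sources, choices)))
# ['foo', 'bar', 'baz', 'foo', 'bar', 'baz', 'foo', 'bar', 'baz']
```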

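Because the choice dataset fixes the order in which sources are drawn (and therefore their ratio in each batch), you will probably want to shuffle each input dataset on its own before combining them, so that the samples within each source are still randomized.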
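A plain-Python sketch of that advice, assuming two hypothetical sources of 4 and 8 items: shuffle each source separately (roughly what calling .shuffle() on each input dataset does), then interleave them with a fixed choice pattern. The pattern fixes the ratio per group of 3 (one item from the first source, two from the second), while the per-source shuffle randomizes which items land in each group.

```python
import random


def shuffled(items, seed):
    """Return a shuffled copy of `items` (stand-in for shuffling one input dataset)."""
    copy = list(items)
    random.Random(seed).shuffle(copy)
    return copy


cats = shuffled([f"cat{i}" for i in range(4)], seed=1)  # hypothetical source 0
dogs = shuffled([f"dog{i}" for i in range(8)], seed=2)  # hypothetical source 1

# Fixed choice pattern: one cat, then two dogs, repeated (a 1:2 ratio).
pattern = [0, 1, 1] * 4
iterators = [iter(cats), iter(dogs)]
mixed = [next(iterators[i]) for i in pattern]

# Every group of 3 (e.g. a batch of 3) holds exactly one cat and two dogs,
# even though which cat and which dogs appear is random.
batches = [mixed[i:i + 3] for i in range(0, len(mixed), 3)]
for batch in batches:
    print(batch)
```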

If you are using an earlier version of TF, you could rely on a combination of zip, flat_map and concatenate like this:

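import tensorflow as tf  # TF 1.x API (make_one_shot_iterator, InteractiveSession)

a = tf.data.Dataset.range(3).repeat()
b = tf.data.Dataset.range(100, 105).repeat()

# zip pairs one element from each dataset; flat_map then turns each
# pair (x, y) into a two-element dataset [x], [y], so the sources alternate.
value = (tf.data.Dataset
  .zip((a, b))
  .flat_map(lambda x, y: tf.data.Dataset.concatenate(
    tf.data.Dataset.from_tensors([x]),
    tf.data.Dataset.from_tensors([y])))
  .make_one_shot_iterator()
  .get_next())

sess = tf.InteractiveSession()

# Prints [0], [100], [1], [101], [2], [102], [0], [103], [1], [104]
for _ in range(10):
  print(value.eval())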
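The same zip-then-flatten idea in plain Python, as a sketch that runs without TensorFlow: itertools.cycle stands in for .repeat(), zip pairs one element from each source, and chain.from_iterable plays the role of the flat_map/concatenate step. Note that this strictly alternates the sources (a 1:1 ratio), and the TF version wraps each value in a one-element tensor, so it prints [0] rather than 0.

```python
from itertools import chain, cycle, islice

a = cycle(range(3))         # like tf.data.Dataset.range(3).repeat()
b = cycle(range(100, 105))  # like tf.data.Dataset.range(100, 105).repeat()

# zip yields pairs (x, y); flattening each pair emits x, then y.
interleaved = chain.from_iterable(zip(a, b))
first_ten = list(islice(interleaved, 10))
print(first_ten)  # [0, 100, 1, 101, 2, 102, 0, 103, 1, 104]
```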
P-Gn answered Sep 25 '22 19:09


Starting from TensorFlow 1.9, you can also use the sample_from_datasets function, which picks the source dataset at random for each element.

For example, the following code

import tensorflow as tf  # TF 1.x API (make_one_shot_iterator, Session)

# Each source yields (index, name) pairs: (0, name), (1, name), (2, name), repeated forever.
datasets = [tf.data.Dataset.from_tensors("foo").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
            tf.data.Dataset.from_tensors("bar").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat(),
            tf.data.Dataset.from_tensors("baz").repeat(3).apply(tf.data.experimental.enumerate_dataset()).repeat()]

dataset = tf.data.experimental.sample_from_datasets(datasets)  # from 1.12
# dataset = tf.contrib.data.sample_from_datasets(datasets)  # between 1.9 and 1.12

iterator = dataset.make_one_shot_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
    for i in range(10):
        print(sess.run(next_element))

will print

(0, b'bar')
(0, b'foo')
(1, b'bar')
(0, b'baz')
(2, b'bar')
(1, b'foo')
(1, b'baz')
(2, b'foo')
(2, b'baz')
(0, b'foo')
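The output shows the pattern to expect: each source still yields its own elements in order ((0, ...), (1, ...), (2, ...)), but the source picked at each step is random. Below is a plain-Python sketch of that sampling idea so it can be run without TensorFlow. sample_from is a hypothetical helper; like TF's sample_from_datasets it takes optional weights, but dropping exhausted sources and continuing is a simplifying assumption of this sketch, not a statement about TensorFlow's behavior.

```python
import random


def sample_from(sources, weights=None, seed=None):
    """At each step pick a source at random (optionally weighted) and emit its next element."""
    rng = random.Random(seed)
    iterators = [iter(s) for s in sources]
    weights = list(weights) if weights is not None else [1.0] * len(iterators)
    while iterators:
        i = rng.choices(range(len(iterators)), weights=weights)[0]
        try:
            value = next(iterators[i])
        except StopIteration:
            # Sketch assumption: drop the exhausted source and its weight.
            del iterators[i]
            del weights[i]
            continue
        yield value


foo = [(i, "foo") for i in range(3)]
bar = [(i, "bar") for i in range(3)]
baz = [(i, "baz") for i in range(3)]

result = list(sample_from([foo, bar, baz], seed=0))  # uniform choice
mostly_foo = list(sample_from([foo, bar], weights=[0.9, 0.1], seed=1))
print(result)
```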
pfm answered Sep 21 '22 19:09