
TensorFlow: how to concatenate multiple datasets with varying batch sizes

Imagine I have:

  • dataset 1 with data [5, 5, 5, 5, 5]
  • dataset 2 with data [4, 4]

I want to take batches from both datasets and concatenate them so that I get batches of size 3, where:

  • I read dataset 1 with batch size 2
  • I read dataset 2 with batch size 1.

I also want to read the final batch even if one of the datasets is emptied before the others. In this instance, I would get [5, 5, 4], [5, 5, 4], [5] as my final result.

How can I do this? I've seen the answer here: Tensorflow how to generate unbalanced combined data sets

It is a good attempt, but it does not work if one of the datasets is emptied before the others: tf.errors.OutOfRangeError is raised prematurely when you try to fetch elements from the dataset that empties first, and I never get the final batch. So I only get [5, 5, 4], [5, 5, 4].
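
For illustration, here is a minimal sketch of that per-iterator approach (my reconstruction of the idea, not the linked answer's exact code) showing the failure mode:

import tensorflow as tf

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4]).batch(1)

# one get_next() per dataset, concatenated into a single combined batch
batch = tf.concat([ds1.make_one_shot_iterator().get_next(),
                   ds2.make_one_shot_iterator().get_next()], axis=0)

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(batch))  # [5 5 4], [5 5 4], then ds2 empties
    except tf.errors.OutOfRangeError:
        pass  # the leftover [5] from ds1 is never produced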

I thought of using tf.contrib.data.choose_from_datasets:

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5]).batch(2)
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4]).batch(1)
# choose_from_datasets expects 0-based indices, wrapped in a dataset of int64 scalars
choice_dataset = tf.data.Dataset.from_tensor_slices(
    tf.constant([0, 1, 0, 1, 0], dtype=tf.int64))
ds = tf.contrib.data.choose_from_datasets([ds1, ds2], choice_dataset)
ds = ds.apply(tf.contrib.data.unbatch())
ds = ds.batch(3, drop_remainder=False)
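
Running it end to end (a usage sketch I added, assuming TF 1.x), this yields the desired batches:

it = ds.make_one_shot_iterator()
nxt = it.get_next()

with tf.Session() as sess:
    try:
        while True:
            print(sess.run(nxt))  # [5 5 4], [5 5 4], [5]
    except tf.errors.OutOfRangeError:
        pass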

This kind of works, but it is rather inelegant (there is an unbatch followed by a batch), and I don't have fine control over exactly what goes into a batch. For instance, if ds1 were [7] * 7 with batch size 2 and ds2 were [2, 2] with batch size 1, I would get [7, 7, 2], [7, 7, 2], [7, 7, 7]. But what if I actually want [7, 7, 2], [7, 7, 2], [7, 7], [7], i.e. keep the number of elements taken from each dataset per batch fixed?

Is there another better solution?

Another idea I had was to try to use tf.data.Dataset.flat_map:

import functools
from functools import partial

ds1 = tf.data.Dataset.from_tensor_slices([5, 5, 5, 5, 5])
ds2 = tf.data.Dataset.from_tensor_slices([4, 4, 4, 4])
batch_sizes = [2, 1]

def concat_and_batch(*inputs):
  # wrap each zipped element in its own dataset, batch it, then chain them together
  concat = partial(functools.reduce, lambda x, y: x.concatenate(y))
  datasets = [tf.data.Dataset.from_tensors(inp) for inp in inputs]
  datasets = [dataset.batch(batch_size)
              for batch_size, dataset in zip(batch_sizes, datasets)]
  return concat(datasets)

dataset = (tf.data.Dataset
           .zip((ds1, ds2))
           .flat_map(concat_and_batch)
           .batch(sum(batch_sizes)))

but it does not seem to work.
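
A quick check I added (not part of the original question) suggests why: zip emits one scalar from each dataset per step, so the inner datasets built with from_tensors each hold a single element and can never fill the requested batch sizes.

# with ds1 and ds2 as defined just above
pair = tf.data.Dataset.zip((ds1, ds2)).make_one_shot_iterator().get_next()
with tf.Session() as sess:
    print(sess.run(pair))  # (5, 4) -- one element from each dataset, not a batch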

asked Oct 19 '18 by cbournho

1 Answer

If you don't mind running a session during the construction of the new dataset, you can do the following:

import tensorflow as tf
import numpy as np

ds1 = tf.data.Dataset.from_tensor_slices([5,5,5,5,5])
ds2 = tf.data.Dataset.from_tensor_slices([4,4])

ds1 = ds1.batch(2)
ds2 = ds2.batch(1)

iter1 = ds1.make_one_shot_iterator()
iter2 = ds2.make_one_shot_iterator()

batch1 = iter1.get_next()
batch2 = iter2.get_next()

sess = tf.Session()

# define a generator that sess.run's both iterators and yields the concatenation
# of the two batches, or whichever batch remains once the other dataset is exhausted
def GetBatch():
    while True:
        try:
            b1 = sess.run(batch1)
        except tf.errors.OutOfRangeError:
            b1 = None
        try:
            b2 = sess.run(batch2)
        except tf.errors.OutOfRangeError:
            b2 = None
        if (b1 is None) and (b2 is None):
            break
        elif b1 is None:
            yield b2
        elif b2 is None:
            yield b1
        else:
            yield np.concatenate((b1,b2))

# create a dataset from the above generator
ds = tf.data.Dataset.from_generator(GetBatch, tf.int32)
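
Optionally (my addition, not part of the original answer), you can also declare the element shape so that downstream ops know each element is a 1-D batch of unknown length:

ds = tf.data.Dataset.from_generator(GetBatch, tf.int32,
                                    output_shapes=tf.TensorShape([None]))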

Note that the above session can be hidden/encapsulated if you wish (for example, inside a function). You can then consume the new dataset as usual:

it = ds.make_one_shot_iterator()
batch = it.get_next()

sess2 = tf.Session()

# drain the combined dataset
try:
    while True:
        print(sess2.run(batch))
except tf.errors.OutOfRangeError:
    pass
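
With the example datasets above, this prints [5 5 4], [5 5 4], and [5], which is exactly the result the question asks for.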
answered Sep 17 '22 by Lior