How to use TensorFlow's tf.cond() with two different Dataset iterators without iterating both?



I want to feed a CNN with the tensor "images". When the placeholder is_training is True, I want this tensor to contain images from the training set (which have a FIXED size); otherwise, I want it to contain images from the test set (which do NOT have a fixed size).

This is needed because during training I take a random fixed-size crop of each training image, while at test time I want to perform dense evaluation and feed entire images into the network (it is fully convolutional, so it will accept them).

My current approach, which does NOT work, is to create two different iterators and select the training or test input with tf.cond, passing the mode at session.run(images, {is_training: True/False}).

The problem is that BOTH iterators are evaluated, so every session.run() advances both of them. The training and test datasets also have different sizes, so I cannot iterate through both of them until the end. Is there a way to make this work, or a smarter way to rewrite it?

I've seen some questions and answers about this, but they always used tf.assign, which takes a NumPy array and assigns it to a variable. I cannot use tf.assign here because I already have tensors coming from the iterators.

Here is my current code. It simply checks the shape of the tensor "images":

train_filenames, train_labels = list_images(args.train_dir)
val_filenames, val_labels = list_images(args.val_dir)

graph = tf.Graph()
with graph.as_default():

    # Preprocessing (for both training and validation):
    def _parse_function(filename, label):
        image_string = tf.read_file(filename)
        image_decoded = tf.image.decode_jpeg(image_string, channels=3)          
        image = tf.cast(image_decoded, tf.float32)

        return image, label

    # Preprocessing (for training)
    def training_preprocess(image, label):

        # Random flip and crop
        image = tf.image.random_flip_left_right(image)
        image = tf.random_crop(image, [args.crop,args.crop, 3])

        return image, label

    # Preprocessing (for validation)
    def val_preprocess(image, label):

        flipped_image = tf.image.flip_left_right(image)
        batch = tf.stack([image,flipped_image],axis=0)

        return batch, label

    # Training dataset
    train_filenames = tf.constant(train_filenames)
    train_labels = tf.constant(train_labels)
    train_dataset = tf.contrib.data.Dataset.from_tensor_slices((train_filenames, train_labels))
    train_dataset = train_dataset.map(_parse_function,num_threads=args.num_workers, output_buffer_size=args.batch_size)
    train_dataset = train_dataset.map(training_preprocess,num_threads=args.num_workers, output_buffer_size=args.batch_size)
    train_dataset = train_dataset.shuffle(buffer_size=10000) 
    batched_train_dataset = train_dataset.batch(args.batch_size)

    # Validation dataset
    val_filenames = tf.constant(val_filenames)
    val_labels = tf.constant(val_labels)
    val_dataset = tf.contrib.data.Dataset.from_tensor_slices((val_filenames, val_labels))
    val_dataset = val_dataset.map(_parse_function,num_threads=1, output_buffer_size=1)
    val_dataset = val_dataset.map(val_preprocess,num_threads=1, output_buffer_size=1)

    train_iterator = tf.contrib.data.Iterator.from_structure(batched_train_dataset.output_types,batched_train_dataset.output_shapes)
    val_iterator = tf.contrib.data.Iterator.from_structure(val_dataset.output_types,val_dataset.output_shapes)

    train_images, train_labels = train_iterator.get_next()
    val_images, val_labels = val_iterator.get_next()

    train_init_op = train_iterator.make_initializer(batched_train_dataset)
    val_init_op = val_iterator.make_initializer(val_dataset)

    # Indicates whether we are in training or in test mode
    is_training = tf.placeholder(tf.bool)

    def f_true():
        with tf.control_dependencies([tf.identity(train_images)]):
            return tf.identity(train_images)

    def f_false():
        return val_images

    images = tf.cond(is_training,f_true,f_false)

    num_images = images.shape

    with tf.Session(graph=graph) as sess:

        sess.run(train_init_op)
        sess.run(val_init_op)

        img = sess.run(images,{is_training:True})

The problem is that when I want to use only the training iterator and comment out the line that runs val_init_op, I get the following error:

FailedPreconditionError (see above for traceback): GetNext() failed because the iterator has not been initialized. Ensure that you have run the initializer operation for this iterator before getting the next element.
 [[Node: IteratorGetNext_1 = IteratorGetNext[output_shapes=[[2,?,?,3], []], output_types=[DT_FLOAT, DT_INT32], _device="/job:localhost/replica:0/task:0/cpu:0"](Iterator_1)]]

If I do not comment out that line, everything works as expected: when is_training is True I get training images, and when is_training is False I get validation images. The issue is that both iterators need to be initialized, and when I evaluate one of them, the other one is advanced too. Since, as I said, the two datasets have different sizes, this causes a problem.
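The behavior described above can be reproduced with a small standalone sketch. It assumes TensorFlow 2.x used through tf.compat.v1 (the question itself uses the older tf.contrib.data API) and calls disable_v2_behavior() so that graph-mode sessions and the TF 1.x implementation of tf.cond are used. The image pipelines are replaced by toy Dataset.range datasets with 5 training and 2 validation elements, and the helper training_steps_before_failure is hypothetical. Because get_next() is called outside the cond, both IteratorGetNext ops run on every training step, so the validation iterator runs out after only two steps.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # assumption: graph mode, Session, and TF 1.x-style tf.cond

graph = tf.Graph()
with graph.as_default():
    # Toy stand-ins (assumption): 5 "training" elements, only 2 "validation" elements.
    train_iterator = tf1.data.make_initializable_iterator(tf1.data.Dataset.range(5))
    val_iterator = tf1.data.make_initializable_iterator(tf1.data.Dataset.range(100, 102))

    # get_next() is called OUTSIDE the branches, as in the question's code.
    train_next = train_iterator.get_next()
    val_next = val_iterator.get_next()

    is_training = tf1.placeholder(tf.bool, shape=[])
    images = tf1.cond(is_training, lambda: train_next, lambda: val_next)


def training_steps_before_failure():
    """Hypothetical helper: counts training-only steps that succeed before an error."""
    with tf1.Session(graph=graph) as sess:
        sess.run([train_iterator.initializer, val_iterator.initializer])
        steps = 0
        try:
            while True:
                sess.run(images, {is_training: True})
                steps += 1
        except tf.errors.OutOfRangeError:
            return steps


# Only 2 steps succeed: the validation iterator was advanced on every
# training step and ran out before the 5-element training dataset did.
print(training_steps_before_failure())
```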

I hope there is a way to solve it! Thanks in advance.

1 Answer

The trick is to call iterator.get_next() inside the f_true() and f_false() functions:

def f_true():
    train_images, _ = train_iterator.get_next()
    return train_images

def f_false():
    val_images, _ = val_iterator.get_next()
    return val_images

images = tf.cond(is_training, f_true, f_false)
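A minimal runnable sketch of this approach follows. It assumes TensorFlow 2.x used through tf.compat.v1 with v2 behavior disabled (so tf.cond uses the TF 1.x implementation the question was written against), and it uses toy Dataset.range pipelines of different lengths in place of the image datasets; the names first_train, val and rest_train are illustrative only. Because each get_next() op is created inside its branch function, it only runs when that branch is selected, so interleaving training and validation fetches does not skip elements of either dataset.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # assumption: graph mode + TF 1.x-style tf.cond

graph = tf.Graph()
with graph.as_default():
    # Toy stand-ins (assumption) for the two pipelines, deliberately of different lengths.
    train_iterator = tf1.data.make_initializable_iterator(tf1.data.Dataset.range(5))
    val_iterator = tf1.data.make_initializable_iterator(tf1.data.Dataset.range(100, 102))
    is_training = tf1.placeholder(tf.bool, shape=[])

    def f_true():
        # Created inside the branch, so this get_next() only runs when is_training is True.
        return train_iterator.get_next()

    def f_false():
        # Likewise, this one only runs when is_training is False.
        return val_iterator.get_next()

    images = tf1.cond(is_training, f_true, f_false)

with tf1.Session(graph=graph) as sess:
    sess.run([train_iterator.initializer, val_iterator.initializer])
    first_train = [int(sess.run(images, {is_training: True})) for _ in range(3)]
    val = [int(sess.run(images, {is_training: False})) for _ in range(2)]
    rest_train = [int(sess.run(images, {is_training: True})) for _ in range(2)]

print(first_train, val, rest_train)  # [0, 1, 2] [100, 101] [3, 4]
```

Each iterator advances only on the steps where its branch is taken, which is why the training sequence resumes at 3 after the two validation fetches.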

The same advice applies to any TensorFlow op that has a side effect, like assigning to a variable: if you want that side effect to happen conditionally, the op must be created inside the appropriate branch function passed to tf.cond().
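For the variable case, here is a small sketch assuming TensorFlow 2.x used through tf.compat.v1 with v2 behavior disabled; the counter variables train_count and val_count are hypothetical. Each branch returns a value that has a control dependency on an assign_add created inside that branch, so three training steps followed by one validation step increment only the matching counter.

```python
import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_v2_behavior()  # assumption: graph mode + TF 1.x-style tf.cond

graph = tf.Graph()
with graph.as_default():
    # Hypothetical counters recording how often each branch's side effect ran.
    train_count = tf1.Variable(0, name="train_count")
    val_count = tf1.Variable(0, name="val_count")
    is_training = tf1.placeholder(tf.bool, shape=[])

    def f_true():
        # The assignment is created inside the branch, so it only runs when taken.
        with tf.control_dependencies([tf1.assign_add(train_count, 1)]):
            return tf.constant(1)

    def f_false():
        with tf.control_dependencies([tf1.assign_add(val_count, 1)]):
            return tf.constant(0)

    step = tf1.cond(is_training, f_true, f_false)
    init = tf1.global_variables_initializer()

with tf1.Session(graph=graph) as sess:
    sess.run(init)
    for flag in [True, True, True, False]:
        sess.run(step, {is_training: flag})
    counts = [int(c) for c in sess.run([train_count, val_count])]

print(counts)  # [3, 1]
```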

