
Correct way of doing data augmentation in TensorFlow with the dataset api?

So, I've been playing around with the TensorFlow dataset API for loading images and segmentation masks (for a semantic segmentation project). I would like to generate batches of images and masks where each image has randomly gone through any combination of pre-processing functions: brightness changes, contrast changes, cropping, saturation changes, and so on. So the first image in my batch may have no pre-processing, the second may have saturation changes, the third may have brightness and saturation changes, and so on.

I tried the following:

import tensorflow as tf
from tensorflow.contrib.data import Dataset, Iterator
import random


def _resize_image(image, mask):
    image = tf.image.resize_bicubic(image, [480, 640], True)
    mask = tf.image.resize_bicubic(mask, [480, 640], True)
    return image, mask

def _corrupt_contrast(image, mask):
    image = tf.image.random_contrast(image, 0, 5)
    return image, mask


def _corrupt_saturation(image, mask):
    image = tf.image.random_saturation(image, 0, 5)
    return image, mask


def _corrupt_brightness(image, mask):
    image = tf.image.random_brightness(image, 5)
    return image, mask


def _random_crop(image, mask):
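    # Reuse one seed so the image and its mask are cropped at the same offset.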
    seed = random.random()
    image = tf.random_crop(image, [240, 320, 3], seed=seed)
    mask = tf.random_crop(mask, [240, 320, 1], seed=seed)
    return image, mask


def _flip_image_horizontally(image, mask):
    seed = random.random()
    image = tf.image.random_flip_left_right(image, seed=seed)
    mask = tf.image.random_flip_left_right(mask, seed=seed)

    return image, mask


def _flip_image_vertically(image, mask):
    seed = random.random()
    image = tf.image.random_flip_up_down(image, seed=seed)
    mask = tf.image.random_flip_up_down(mask, seed=seed)

    return image, mask


def _normalize_data(image, mask):
    image = tf.cast(image, tf.float32)
    image = image / 255.0

    mask = tf.cast(mask, tf.float32)
    mask = mask / 255.0

    return image, mask


def _parse_data(image_paths, mask_paths):
    image_content = tf.read_file(image_paths)
    mask_content = tf.read_file(mask_paths)

    images = tf.image.decode_png(image_content, channels=3)
    masks = tf.image.decode_png(mask_content, channels=1)

    return images, masks


def data_batch(image_paths, mask_paths, params, batch_size=4, num_threads=2):
    # Convert lists of paths to tensors for tensorflow
    images_name_tensor = tf.constant(image_paths)
    mask_name_tensor = tf.constant(mask_paths)

    # Create dataset out of the 2 files:
    data = Dataset.from_tensor_slices(
        (images_name_tensor, mask_name_tensor))

    # Parse images and labels
    data = data.map(
        _parse_data, num_threads=num_threads, output_buffer_size=6 * batch_size)

    # Normalize images and masks for vals. between 0 and 1
    data = data.map(_normalize_data, num_threads=num_threads, output_buffer_size=6 * batch_size)

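    # Each of the following blocks is meant to apply its augmentation with
    # a 50% chance, gated by the corresponding `params` entry.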
    if params['crop'] and not random.randint(0, 1):
        data = data.map(_random_crop, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['brightness'] and not random.randint(0, 1):
        data = data.map(_corrupt_brightness, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['contrast'] and not random.randint(0, 1):
        data = data.map(_corrupt_contrast, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['saturation'] and not random.randint(0, 1):
        data = data.map(_corrupt_saturation, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    if params['flip_horizontally'] and not random.randint(0, 1):
        data = data.map(_flip_image_horizontally,
                    num_threads=num_threads, output_buffer_size=6 * batch_size)

    if params['flip_vertically'] and not random.randint(0, 1):
        data = data.map(_flip_image_vertically, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    # Shuffle the data queue
    data = data.shuffle(len(image_paths))

    # Create a batch of data
    data = data.batch(batch_size)

    data = data.map(_resize_image, num_threads=num_threads,
                    output_buffer_size=6 * batch_size)

    # Create iterator
    iterator = Iterator.from_structure(data.output_types, data.output_shapes)

    # Next element Op
    next_element = iterator.get_next()

    # Data set init. op
    init_op = iterator.make_initializer(data)

    return next_element, init_op
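
For reference, a minimal way to drive this would be the following (the paths and params below are placeholders, not my real data):

# Hypothetical driver code; paths are placeholders.
image_paths = ['data/images/0001.png', 'data/images/0002.png']
mask_paths = ['data/masks/0001.png', 'data/masks/0002.png']
params = {'crop': True, 'brightness': True, 'contrast': True,
          'saturation': True, 'flip_horizontally': True,
          'flip_vertically': True}

next_element, init_op = data_batch(image_paths, mask_paths, params)

with tf.Session() as sess:
    sess.run(init_op)
    image_batch, mask_batch = sess.run(next_element)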

But all batches returned by this have the same transformations applied to them, not different combinations. My guess is that the random.randint calls persist and are not actually re-run for each batch. If so, how do I fix this to get the desired result? An example of how I plan to use it (I feel that's irrelevant to the problem, but people might still want to know) can be found here.
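
To illustrate what I suspect is going on, here is a minimal, self-contained example (hypothetical, not from my pipeline) showing that a Python-level coin flip is baked into the graph once and never re-run:

import random
import tensorflow as tf

flag = random.randint(0, 1)   # evaluated ONCE, at graph-build time
x = tf.constant(1.0)
if flag:                      # this branch is frozen into the graph
    x = x * 2.0

with tf.Session() as sess:
    for _ in range(3):
        print(sess.run(x))    # prints the same value every run; the coin never re-flips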

asked Dec 12 '17 by Hasnain Raza


1 Answer

So the problem was indeed that the control flow in the if statements is based on Python variables, so it is executed only once, when the graph is created. To do what I want, I had to define a placeholder containing the boolean values for whether to apply each function (and feed in a new boolean tensor per iteration to change the augmentation), with the control flow handled by tf.cond. I pushed the new code to the GitHub link I posted in the question above, if anyone is interested.
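
For anyone who wants a concrete picture of the tf.cond approach, here is a minimal sketch. Note that it is not the code from the linked repo: instead of feeding booleans through a placeholder, this variant draws the coin inside the graph with tf.random_uniform, so every element gets an independent decision. The helper names (_maybe_apply, _augment) are illustrative.

import tensorflow as tf

def _maybe_apply(augment_fn, image, mask, prob=0.5):
    # Draw a fresh coin inside the graph, so the decision is re-made
    # for every element, every time the pipeline runs.
    coin = tf.random_uniform([], 0.0, 1.0)
    return tf.cond(coin < prob,
                   lambda: augment_fn(image, mask),
                   lambda: (image, mask))

def _augment(image, mask):
    # Each augmentation is decided independently, so one element may get
    # brightness only, another saturation plus contrast, another nothing.
    image, mask = _maybe_apply(_corrupt_brightness, image, mask)
    image, mask = _maybe_apply(_corrupt_saturation, image, mask)
    image, mask = _maybe_apply(_corrupt_contrast, image, mask)
    return image, mask

# Replaces the block of Python-level `if` statements in data_batch:
data = data.map(_augment, num_threads=num_threads,
                output_buffer_size=6 * batch_size)

The crop and flip augmentations fit the same pattern if the true branch applies the deterministic ops (for example, tf.image.flip_left_right) to both the image and the mask, which also does away with the shared-seed trick.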

answered Oct 15 '22 by Hasnain Raza