 

How is data augmentation implemented in TensorFlow?

Based on the TensorFlow ConvNet tutorial, some points are not readily apparent to me:

  • are the distorted images actually added to the pool of original images?
  • or are the distorted images used instead of the originals?
  • how many distorted images are being produced? (i.e. what augmentation factor was defined?)

The flow of functions for the tutorial seems to be as follows:

cifar10_train.py

def train():
    """Train CIFAR-10 for a number of steps."""
    with tf.Graph().as_default():
        [...]
        # Get images and labels for CIFAR-10.
        images, labels = cifar10.distorted_inputs()
        [...]

cifar10.py

def distorted_inputs():
    """Construct distorted input for CIFAR training using the Reader ops.

    Returns:
      images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
      labels: Labels. 1D tensor of [batch_size] size.

    Raises:
      ValueError: If no data_dir
    """
    if not FLAGS.data_dir:
        raise ValueError('Please supply a data_dir')
    data_dir = os.path.join(FLAGS.data_dir, 'cifar-10-batches-bin')
    return cifar10_input.distorted_inputs(data_dir=data_dir,
                                          batch_size=FLAGS.batch_size)

and finally cifar10_input.py

def distorted_inputs(data_dir, batch_size):
    """Construct distorted input for CIFAR training using the Reader ops.

    Args:
    data_dir: Path to the CIFAR-10 data directory.
    batch_size: Number of images per batch.

    Returns:
    images: Images. 4D tensor of [batch_size, IMAGE_SIZE, IMAGE_SIZE, 3] size.
    labels: Labels. 1D tensor of [batch_size] size.
    """
    filenames = [os.path.join(data_dir, 'data_batch_%d.bin' % i) for i in xrange(1, 6)]
    for f in filenames:
        if not tf.gfile.Exists(f):
            raise ValueError('Failed to find file: ' + f)

    # Create a queue that produces the filenames to read.
    filename_queue = tf.train.string_input_producer(filenames)

    # Read examples from files in the filename queue.
    read_input = read_cifar10(filename_queue)
    reshaped_image = tf.cast(read_input.uint8image, tf.float32)

    height = IMAGE_SIZE
    width = IMAGE_SIZE

    # Image processing for training the network. Note the many random
    # distortions applied to the image.

    # Randomly crop a [height, width] section of the image.
    distorted_image = tf.random_crop(reshaped_image, [height, width, 3])

    # Randomly flip the image horizontally.
    distorted_image = tf.image.random_flip_left_right(distorted_image)

    # Because these operations are not commutative, consider randomizing
    # the order of their operation.
    distorted_image = tf.image.random_brightness(distorted_image, max_delta=63)
    distorted_image = tf.image.random_contrast(distorted_image, lower=0.2, upper=1.8)

    # Subtract off the mean and divide by the standard deviation of the pixels.
    float_image = tf.image.per_image_whitening(distorted_image)

    # Ensure that the random shuffling has good mixing properties.
    min_fraction_of_examples_in_queue = 0.4
    min_queue_examples = int(NUM_EXAMPLES_PER_EPOCH_FOR_TRAIN *
                             min_fraction_of_examples_in_queue)
    print('Filling queue with %d CIFAR images before starting to train. '
          'This will take a few minutes.' % min_queue_examples)

    # Generate a batch of images and labels by building up a queue of examples.
    return _generate_image_and_label_batch(float_image, read_input.label,
                                           min_queue_examples, batch_size,
                                           shuffle=True)
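
For reference, per_image_whitening (renamed tf.image.per_image_standardization in later TensorFlow releases) just normalizes each image. A rough NumPy sketch of what it computes, not the tutorial's own code:

import numpy as np

def per_image_whitening(image):
    # Scale the image to zero mean and roughly unit standard deviation.
    # The stddev is floored at 1/sqrt(num_elements) so a constant image
    # does not cause a division by zero.
    mean = image.mean()
    adjusted_stddev = max(image.std(), 1.0 / np.sqrt(image.size))
    return (image - mean) / adjusted_stddev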
asked by pepe


1 Answer

are the distorted images actually added to the pool of original images?

It depends on the definition of the pool. In TensorFlow, ops are the basic objects in your computation graph, and here data production is an op itself. Thus you do not have a finite set of training samples; instead, you have a potentially infinite set of samples generated from the training set.
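
For instance, here is a minimal standalone sketch (hypothetical code, using the same TF 1.x-era API as the tutorial) showing that a distortion op re-samples its randomness on every run of the graph:

import tensorflow as tf

# A fixed synthetic image, so any difference between runs comes from
# the random ops alone.
image = tf.reshape(tf.range(32 * 32 * 3, dtype=tf.float32), [32, 32, 3])
distorted = tf.image.random_flip_left_right(
    tf.random_crop(image, [24, 24, 3]))

with tf.Session() as sess:
    a = sess.run(distorted)
    b = sess.run(distorted)
    # a and b will generally differ: each run draws a new crop offset
    # and a new coin flip, so the op behaves like a sample generator.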

or are the distorted images used instead of the originals?

As you can see from the source you included, a sample is read from the training data and then randomly transformed, so there is only a very small probability of your network ever seeing an unaltered image (especially since random cropping is applied, which always modifies the image).
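
To put a rough number on it: the tutorial sets IMAGE_SIZE to 24, so tf.random_crop returns a 24x24 patch of the 32x32 original, meaning the network never sees the full original image at all. The crop alone has (32 - 24 + 1)^2 = 81 possible offsets, the horizontal flip doubles that to 162 geometric variants, and random_brightness and random_contrast draw from continuous ranges.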

how many distorted images are being produced? (i.e. what augmentation factor was defined?)

There is no such factor; this is a never-ending process. Think of it in terms of random access to a potentially infinite source of data, as this is effectively what is happening here. Every single batch can be different from the previous one.
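
Concretely, each call to sess.run on the batch tensor pulls a fresh batch with newly sampled distortions. A minimal usage sketch (assuming the tutorial's cifar10 module is importable and FLAGS.data_dir is set):

import tensorflow as tf
import cifar10

images, labels = cifar10.distorted_inputs()

with tf.Session() as sess:
    # The input pipeline runs in background threads.
    tf.train.start_queue_runners(sess=sess)
    batch_1 = sess.run(images)  # one batch of freshly distorted images
    batch_2 = sess.run(images)  # a different batch; distortions are re-sampled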

answered by lejlot