Starting from the TensorFlow CNN example, I'm trying to modify the model to take multiple images as input (so that the input has not just 3 channels but a multiple of 3, by stacking images). To augment the input, I'm trying to use the random image operations TensorFlow provides, such as flipping, contrast and brightness. My current solution for applying the same random distortion to all input images is to use a fixed seed value for these operations:
def distort_image(image):
    flipped_image = tf.image.random_flip_left_right(image, seed=42)
    contrast_image = tf.image.random_contrast(flipped_image, lower=0.2, upper=1.8, seed=43)
    brightness_image = tf.image.random_brightness(contrast_image, max_delta=0.2, seed=44)
    return brightness_image
This function is called multiple times at graph construction time, once for each image, so I assumed each image would use the same random number sequence and, consequently, that the same image operations would be applied to every image in my input sequence.
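The reasoning above can be illustrated outside TensorFlow. The sketch below is a NumPy analogy, not TensorFlow's implementation: it assumes each seeded op owns a private generator that starts from its seed and advances once per execution, and models three per-image flip ops created with the same seed (make_op is a hypothetical helper). When every op runs exactly once per example, in order, the k-th draw of each op is identical, so all images in an example get the same flip decision.

```python
import numpy as np

def make_op(seed):
    """Stand-in for one seeded random op: a private generator replayed from `seed`."""
    rng = np.random.RandomState(seed)
    return lambda: rng.uniform()

# One "flip" op per image in the sequence, all created with the same seed (like seed=42).
flip_ops = [make_op(42) for _ in range(3)]

# Lock-step execution: each example calls every op exactly once, in the same order.
draws = [[op() for op in flip_ops] for _ in range(5)]

# Within each example all images got the same draw, hence the same flip decision.
flips = [example[0] < 0.5 for example in draws]
for example in draws:
    assert len(set(example)) == 1
print(flips)
```

The guarantee holds only because the calls are strictly ordered; nothing in the ops themselves ties the n-th draw of one op to the n-th draw of another.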
# ...
# distort images
distorted_prediction = distort_image(seq_record.prediction)
distorted_input = []
for i in xrange(INPUT_SEQ_LENGTH):
    distorted_input.append(distort_image(seq_record.input[i,:,:,:]))
stacked_distorted_input = tf.concat(2, distorted_input)

# Ensure that the random shuffling has good mixing properties.
min_queue_examples = int(num_examples_per_epoch *
                         MIN_FRACTION_EXAMPLES_IN_QUEUE)

# Generate a batch of sequences and prediction by building up a queue of examples.
return generate_sequence_batch(stacked_distorted_input, distorted_prediction, min_queue_examples,
                               batch_size, shuffle=True)
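For reference, concatenating along axis 2 turns INPUT_SEQ_LENGTH images of shape (H, W, 3) into one tensor of shape (H, W, 3 * INPUT_SEQ_LENGTH). The NumPy sketch below (image size and sequence length are hypothetical) checks that layout and one consequence worth knowing: a left-right flip only reverses the width axis, so flipping the stacked array once gives exactly the same result as flipping every image separately with the same decision.

```python
import numpy as np

INPUT_SEQ_LENGTH = 4   # hypothetical sequence length
H, W = 8, 6            # hypothetical image size

rng = np.random.RandomState(0)
sequence = [rng.rand(H, W, 3).astype(np.float32) for _ in range(INPUT_SEQ_LENGTH)]

# Same layout as concatenating along axis 2: (H, W, 3) x N -> (H, W, 3N).
stacked = np.concatenate(sequence, axis=2)
assert stacked.shape == (H, W, 3 * INPUT_SEQ_LENGTH)

# A left-right flip reverses axis 1 only, so flipping the whole stack once ...
flipped_stack = stacked[:, ::-1, :]
# ... equals flipping each image individually and stacking afterwards.
flipped_each = np.concatenate([img[:, ::-1, :] for img in sequence], axis=2)
assert np.array_equal(flipped_stack, flipped_each)

# Channels 3*i .. 3*i+2 of the stack still hold image i.
image_1 = flipped_stack[:, :, 3:6]
assert np.array_equal(image_1, sequence[1][:, ::-1, :])
```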
In theory, this works, and after some test runs it really did seem to solve my problem. After a while, however, I found out that I have a race condition, because I use the input pipeline from the CNN example code with multiple threads (the method TensorFlow suggests for improving performance and reducing memory consumption at runtime):
def generate_sequence_batch(sequence_in, prediction, min_queue_examples,
                            batch_size, shuffle=True):
    # `shuffle` is accepted so the call above (shuffle=True) works; this version always shuffles.
    num_preprocess_threads = 8  # <-- !!! several threads run the distortion ops concurrently
    sequence_batch, prediction_batch = tf.train.shuffle_batch(
        [sequence_in, prediction],
        batch_size=batch_size,
        num_threads=num_preprocess_threads,
        capacity=min_queue_examples + 3 * batch_size,
        min_after_dequeue=min_queue_examples)
    return sequence_batch, prediction_batch
Because multiple threads create my examples, it is no longer guaranteed that the random operations for the images of one example are executed in the right order, i.e. that every image in an example receives the same random values.
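The race can be reproduced without TensorFlow. Below is a minimal NumPy simulation, again assuming each seeded op owns a private generator that advances once per execution (an analogy for stateful random ops, not their actual implementation). The hypothetical run helper executes (thread, image) steps in a given order, standing in for whatever order the scheduler picks for two preprocessing threads that each build a two-image example.

```python
import numpy as np

def make_op(seed):
    """Stand-in for one seeded random op with its own replayable generator."""
    rng = np.random.RandomState(seed)
    return lambda: rng.uniform()

def run(schedule, seed=42):
    """Execute (thread, image_index) steps in order; return the draws each thread received."""
    flip_ops = [make_op(seed) for _ in range(2)]  # one op per image, same seed
    results = {}
    for thread, img in schedule:
        results.setdefault(thread, [None, None])[img] = flip_ops[img]()
    return results

# Threads A and B each build one example consisting of two images.
lock_step = [("A", 0), ("A", 1), ("B", 0), ("B", 1)]
interleaved = [("A", 0), ("B", 0), ("B", 1), ("A", 1)]

ok = run(lock_step)
bad = run(interleaved)

# In lock-step, both images of each example share a draw.
assert ok["A"][0] == ok["A"][1] and ok["B"][0] == ok["B"][1]
# Interleaved, thread A pairs draw #0 of image 0's op with draw #1 of image 1's op.
assert bad["A"][0] != bad["A"][1]
```

Nothing about the ops changed between the two runs; only the execution order did, which is exactly what multiple reader threads do not control.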
This is where I got completely stuck. Does anyone know how to apply the same image distortion to multiple images?
Some thoughts of mine:
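Here is what I came up with after looking at the code of random_flip_up_down and random_flip_left_right in TensorFlow. The random values are drawn once and passed in, so the image and its label receive exactly the same flips: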
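def image_distortions(image, distortions):
    # Pre-1.0 API: tf.pack (later tf.stack) and tf.reverse with a boolean mask per axis.
    # Flip left-right (axis 1 of an H x W x C image) if the first value is < 0.5.
    distort_left_right_random = distortions[0]
    mirror = tf.less(tf.pack([1.0, distort_left_right_random, 1.0]), 0.5)
    image = tf.reverse(image, mirror)
    # Flip up-down (axis 0) if the second value is < 0.5.
    distort_up_down_random = distortions[1]
    mirror = tf.less(tf.pack([distort_up_down_random, 1.0, 1.0]), 0.5)
    image = tf.reverse(image, mirror)
    return image

# Draw the random values once and reuse them for every tensor that must match.
distortions = tf.random_uniform([2], 0, 1.0, dtype=tf.float32)
image = image_distortions(image, distortions)
label = image_distortions(label, distortions)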
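The same pattern translated to NumPy, as a sketch for checking the logic rather than TensorFlow code: the two uniforms are drawn once, outside the function, so every call sees identical values no matter who makes it. The label here is a hypothetical copy of the image, which makes it easy to verify that both received the same flips.

```python
import numpy as np

def image_distortions(image, distortions):
    """Flip an (H, W, C) array using pre-drawn uniforms, mirroring the TensorFlow version."""
    if distortions[0] < 0.5:   # left-right: reverse the width axis
        image = image[:, ::-1, :]
    if distortions[1] < 0.5:   # up-down: reverse the height axis
        image = image[::-1, :, :]
    return image

rng = np.random.RandomState(0)
image = rng.rand(5, 4, 3)
label = image.copy()  # hypothetical label that must stay aligned with the image

distortions = rng.uniform(0.0, 1.0, size=2)        # drawn ONCE ...
image_out = image_distortions(image, distortions)  # ... and applied to both tensors
label_out = image_distortions(label, distortions)
assert np.array_equal(image_out, label_out)
```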
I would do something like this using tf.case. It lets you specify what to return when a given condition holds: https://www.tensorflow.org/api_docs/python/tf/case
import tensorflow as tf

def distort(image, x):
    # flip vertically, horizontally, both, or do nothing
    image = tf.case({
        tf.equal(x, 0): lambda: tf.reverse(image, [0]),
        tf.equal(x, 1): lambda: tf.reverse(image, [1]),
        tf.equal(x, 2): lambda: tf.reverse(image, [0, 1]),
    }, default=lambda: image, exclusive=True)
    return image

def random_distortion(image):
    # To distort several images identically, draw x once and pass the same x[0] to distort() for each.
    x = tf.random_uniform([1], 0, 4, dtype=tf.int32)
    return distort(image, x[0])
To check that it works:
import numpy as np
import matplotlib.pyplot as plt

# create image
image = np.zeros((25, 25))
image[:10, 5:10] = 1.

# create subplots
fig, axes = plt.subplots(2, 2)
for i in axes.flatten(): i.axis('off')

with tf.Session() as sess:
    for i in range(4):
        distorted_img = sess.run(distort(image, i))
        axes[i % 2][i // 2].imshow(distorted_img, cmap='gray')

plt.show()
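Both answers rest on the same idea: draw the random parameters once per example and feed them to a deterministic function applied to every image. The NumPy sketch below extends that idea to the question's contrast and brightness ranges. It is an illustration, not TensorFlow code: distort_sequence is a hypothetical helper, and the adjustment formulas follow the documented behaviour of tf.image.adjust_contrast ((x - mean) * factor + mean, per channel) and tf.image.adjust_brightness (x + delta).

```python
import numpy as np

def distort(image, x):
    """NumPy analogue of the tf.case answer: 0 = vertical, 1 = horizontal, 2 = both, 3 = none."""
    if x == 0:
        return image[::-1]
    if x == 1:
        return image[:, ::-1]
    if x == 2:
        return image[::-1, ::-1]
    return image

def adjust(image, contrast_factor, delta):
    """Contrast (per-channel mean) followed by brightness, with fixed parameters."""
    mean = image.mean(axis=(0, 1), keepdims=True)
    return (image - mean) * contrast_factor + mean + delta

def distort_sequence(images, rng):
    """Draw every random parameter once, then apply the same values to all images."""
    x = rng.randint(0, 4)                    # one of the four flip cases
    contrast_factor = rng.uniform(0.2, 1.8)  # lower/upper from the question's random_contrast
    delta = rng.uniform(-0.2, 0.2)           # max_delta=0.2 from the question's random_brightness
    return [adjust(distort(img, x), contrast_factor, delta) for img in images]

rng = np.random.RandomState(0)
base = rng.rand(6, 5, 3)
sequence = [base.copy() for _ in range(3)]  # identical inputs make the check easy
out = distort_sequence(sequence, rng)
assert all(np.allclose(out[0], o) for o in out[1:])
```

Because the parameters are plain values fixed before any image is touched, it no longer matters in which order, or on which thread, the per-image work runs.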