I am building a preprocessing and data augmentation pipeline for my image segmentation dataset. Keras offers a powerful API for this, but I ran into the problem of reproducing the same augmentation on the image as well as on the segmentation mask (a second image). Both images must undergo exactly the same manipulations. Is this not supported yet?
https://www.tensorflow.org/tutorials/images/data_augmentation
Example / Pseudocode
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers

# SEED_VAL is any fixed integer seed, defined elsewhere
data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED_VAL),
    layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=SEED_VAL),
    # range given as (lower, upper): zoom in by up to 20%
    layers.experimental.preprocessing.RandomZoom(height_factor=(-0.2, 0.0), fill_mode="constant", fill_value=0, seed=SEED_VAL),
])
(train_ds, test_ds), info = tfds.load('somedataset', split=['train[:80%]', 'train[80%:]'], with_info=True)
This code does not work, but it illustrates how my dream API would behave:
train_ds = train_ds.map(lambda datapoint: data_augmentation((datapoint['image'], datapoint['segmentation_mask']), training=True))
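One common workaround for the "same transform on both" requirement is to concatenate image and mask along the channel axis, run the random op once on the stacked tensor, and split the result again. Below is a minimal sketch of that idea using only `tf.image` flips and a random crop (none of these interpolate, so mask labels stay intact). The function name `augment_pair` and the default `crop_size` are made up for illustration, and the inputs are assumed to be at least as large as the crop:

```python
import tensorflow as tf


def augment_pair(image, mask, crop_size=(96, 96)):
    """Apply identical random flips and one random crop to an image and its mask.

    Hypothetical helper: assumes 3-D tensors (H, W, C) with a statically known
    channel count and H, W >= crop_size.
    """
    n_img = image.shape[-1]  # number of image channels
    # Stack image and mask so every random decision is made only once.
    stacked = tf.concat(
        [tf.cast(image, tf.float32), tf.cast(mask, tf.float32)], axis=-1)
    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_flip_up_down(stacked)
    stacked = tf.image.random_crop(
        stacked, size=[crop_size[0], crop_size[1], stacked.shape[-1]])
    # Split back into image and mask; restore the mask's original dtype.
    image_aug = stacked[..., :n_img]
    mask_aug = tf.cast(stacked[..., n_img:], mask.dtype)
    return image_aug, mask_aug


# Synthetic demo: a 128x128 RGB image and a one-channel integer mask.
image = tf.random.uniform((128, 128, 3))
mask = tf.cast(image[..., :1] > 0.5, tf.uint8)
image_aug, mask_aug = augment_pair(image, mask)
```

On a dataset this could be applied as `train_ds.map(lambda d: augment_pair(d['image'], d['segmentation_mask']))`. Image-only changes such as brightness or contrast would be applied to `image_aug` after the split.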
Alternative
The alternative is to code a custom load and manipulation / randomization method, as proposed in the image segmentation tutorial (https://www.tensorflow.org/tutorials/images/segmentation).
Any tips on state-of-the-art data augmentation for this type of dataset are much appreciated :)
Here is my own implementation, in case someone else wants to use the TF built-ins (tf.image API) as of December 2020 :)
import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# IMG_SIZE and IMAGE_CHANNELS are assumed to be defined elsewhere

@tf.function
def load_image(datapoint, augment=True):
    # resize image and mask; nearest-neighbour keeps the mask's class labels intact
    input_image = tf.image.resize(datapoint['image'], (IMG_SIZE, IMG_SIZE))
    input_mask = tf.image.resize(datapoint['segmentation_mask'], (IMG_SIZE, IMG_SIZE), method='nearest')

    # optionally convert to grayscale, then rescale the image to [0, 1]
    if IMAGE_CHANNELS == 1:
        input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image, tf.float32) / 255.0

    # augmentation
    if augment:
        # zoom in a bit: crop the central 75% and resize back
        if tf.random.uniform(()) > 0.5:
            # crop the already rescaled image so both branches have the same
            # channels and value range
            input_image = tf.image.central_crop(input_image, 0.75)
            input_mask = tf.image.central_crop(input_mask, 0.75)
            # resize
            input_image = tf.image.resize(input_image, (IMG_SIZE, IMG_SIZE))
            input_mask = tf.image.resize(input_mask, (IMG_SIZE, IMG_SIZE), method='nearest')

        # random brightness adjustment (illumination), image only
        input_image = tf.image.random_brightness(input_image, 0.3)
        # random contrast adjustment, image only
        input_image = tf.image.random_contrast(input_image, 0.2, 0.5)

        # random horizontal flip, applied to image and mask together
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        # random vertical flip, applied to image and mask together
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 30° steps: 12 steps of pi/6 cover the full circle
        rot_factor = tf.cast(tf.random.uniform(shape=[], maxval=12, dtype=tf.int32), tf.float32)
        angle = np.pi / 6 * rot_factor
        input_image = tfa.image.rotate(input_image, angle)
        input_mask = tfa.image.rotate(input_mask, angle)

    return input_image, input_mask
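Another way to keep image and mask in sync, instead of drawing one random number and branching by hand as in the function above, is the family of stateless random ops in `tf.image` (available since TF 2.4, as far as I know). They take an explicit `seed`, so calling the same op with the same seed on image and mask gives the same decision. A minimal sketch, where `flip_pair` is a made-up helper name and the `seed + 1` derivation is just one simple way to get a second seed:

```python
import tensorflow as tf


def flip_pair(image, mask, seed):
    """Randomly flip image and mask in sync using stateless ops.

    Hypothetical helper: `seed` is a shape-[2] integer tensor, as required
    by the tf.image.stateless_random_* functions.
    """
    seed = tf.convert_to_tensor(seed)
    # Same seed for image and mask -> same left/right decision for both.
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed=seed)
    # Use a different seed for the vertical flip so the two flip
    # decisions are not tied to each other.
    seed_ud = seed + 1
    image = tf.image.stateless_random_flip_up_down(image, seed=seed_ud)
    mask = tf.image.stateless_random_flip_up_down(mask, seed=seed_ud)
    return image, mask


# Synthetic demo: draw a fresh seed per example from a generator.
rng = tf.random.Generator.from_seed(42)
seed = rng.make_seeds(1)[:, 0]  # shape [2]
image = tf.random.uniform((64, 64, 3))
mask = tf.cast(image[..., :1] > 0.5, tf.uint8)
image_aug, mask_aug = flip_pair(image, mask, seed)
```

Each example needs its own seed, otherwise every image receives the same flips. Reusing a seed reproduces exactly the same augmentation, which can be handy for debugging.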