
Keras data augmentation pipeline for image segmentation dataset (image and mask with same manipulation)

I am building a preprocessing and data augmentation pipeline for my image segmentation dataset. Keras has a powerful API for this, but I ran into a problem: how do I reproduce the same augmentation on the image and on the segmentation mask (a second image)? Both must undergo exactly the same manipulations. Is this not supported yet?

https://www.tensorflow.org/tutorials/images/data_augmentation

Example / Pseudocode

import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow.keras import layers

data_augmentation = tf.keras.Sequential([
    layers.experimental.preprocessing.RandomFlip(mode="horizontal_and_vertical", seed=SEED_VAL),
    layers.experimental.preprocessing.RandomRotation(factor=0.4, fill_mode="constant", fill_value=0, seed=SEED_VAL),
    layers.experimental.preprocessing.RandomZoom(height_factor=(-0.2, 0.0), fill_mode="constant", fill_value=0, seed=SEED_VAL)])

(train_ds, test_ds), info = tfds.load('somedataset', split=['train[:80%]', 'train[80%:]'], with_info=True)

The following code does not work, but it illustrates how my dream API would behave:

train_ds = train_ds.map(lambda datapoint: data_augmentation((datapoint['image'], datapoint['segmentation_mask']), training=True))
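One workaround that is often used for this, sketched here under the assumption that the image has shape (H, W, C) and the mask has shape (H, W, 1), is to concatenate both along the channel axis, apply each random spatial op once to the stacked tensor, and split the result again. The name `augment_pair` and the crop size are hypothetical and chosen only for illustration; only core `tf.image` calls are used.

```python
import tensorflow as tf


def augment_pair(image, mask, crop_size=(96, 96)):
    """Apply one random flip and one random crop to image and mask together.

    Assumes image is (H, W, C), mask is (H, W, 1), and both are at least
    as large as crop_size. Names and sizes are illustrative only.
    """
    n_img_channels = image.shape[-1]
    # Stack along the channel axis so each spatial op sees a single tensor
    # and therefore makes a single random decision for image and mask.
    stacked = tf.concat(
        [tf.cast(image, tf.float32), tf.cast(mask, tf.float32)], axis=-1)
    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_flip_up_down(stacked)
    stacked = tf.image.random_crop(
        stacked, size=(crop_size[0], crop_size[1], stacked.shape[-1]))
    # Split back into image and mask, restoring the mask's original dtype.
    aug_image = stacked[..., :n_img_channels]
    aug_mask = tf.cast(stacked[..., n_img_channels:], mask.dtype)
    return aug_image, aug_mask
```

With a tfds-style dataset this could be mapped as `train_ds.map(lambda d: augment_pair(d['image'], d['segmentation_mask']))`. Ops that interpolate (rotation, zoom) need extra care with this trick, because mask labels should not be blended between classes.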

Alternative

The alternative is to write a custom loading and manipulation / randomization method, as proposed in the image segmentation tutorial (https://www.tensorflow.org/tutorials/images/segmentation).
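A custom method along those lines could also rely on the stateless random ops in `tf.image`, which make the same decision whenever they receive the same seed. The sketch below is an assumption about how that might look, not the tutorial's code: `augment_with_seed` and `map_fn` are hypothetical names, and the per-example seed comes from a `tf.random.Generator`.

```python
import tensorflow as tf


def augment_with_seed(image, mask, seed):
    """Flip image and mask identically by reusing one seed per example.

    `seed` is a shape-[2] integer tensor, as the stateless ops require.
    """
    image = tf.image.stateless_random_flip_left_right(image, seed=seed)
    mask = tf.image.stateless_random_flip_left_right(mask, seed=seed)
    # Use a different seed so the vertical flip is decided independently
    # of the horizontal one, but still identically for image and mask.
    seed_ud = seed + 1
    image = tf.image.stateless_random_flip_up_down(image, seed=seed_ud)
    mask = tf.image.stateless_random_flip_up_down(mask, seed=seed_ud)
    return image, mask


# A generator that hands out a fresh seed for every example.
rng = tf.random.Generator.from_seed(123, alg="philox")


def map_fn(datapoint):
    """Hypothetical map function for a tfds-style dict dataset."""
    seed = rng.make_seeds(1)[:, 0]  # shape [2]
    return augment_with_seed(
        datapoint["image"], datapoint["segmentation_mask"], seed)
```

It would then be applied with `train_ds.map(map_fn)`. Photometric changes such as brightness or contrast can be added for the image alone, since they do not move pixels.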

Any tips on state-of-the-art data augmentation for this type of dataset are much appreciated :)

Tom asked Oct 24 '25 19:10



1 Answer

Here is my own implementation, as of December 2020, in case someone else wants to use the TensorFlow built-ins (tf.image API). The rotation step additionally relies on tensorflow_addons (tfa) :)

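import numpy as np
import tensorflow as tf
import tensorflow_addons as tfa

# IMG_SIZE and IMAGE_CHANNELS are constants assumed to be defined elsewhere.
@tf.function
def load_image(datapoint, augment=True):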
    
    input_image = datapoint['image']
    input_mask = datapoint['segmentation_mask']

    # augmentation: zoom in a bit by cropping the centre of the
    # full-resolution inputs *before* resizing, so the resize, grayscale
    # and rescale steps below are applied to zoomed samples as well
    if augment:
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.central_crop(input_image, 0.75)
            input_mask = tf.image.central_crop(input_mask, 0.75)

    # resize image and mask; nearest-neighbour keeps the mask labels
    # discrete instead of blending neighbouring class ids
    input_image = tf.image.resize(input_image, (IMG_SIZE, IMG_SIZE))
    input_mask = tf.image.resize(input_mask, (IMG_SIZE, IMG_SIZE), method='nearest')

    # convert to grayscale if requested, then rescale the image to [0, 1]
    if IMAGE_CHANNELS == 1:
        input_image = tf.image.rgb_to_grayscale(input_image)
    input_image = tf.cast(input_image, tf.float32) / 255.0

    # remaining augmentations
    if augment:
        
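        # random brightness adjustment (image only; the mask is unaffected
        # by photometric changes)
        input_image = tf.image.random_brightness(input_image, 0.3)
        # random contrast adjustment (factors in [0.2, 0.5] always reduce contrast)
        input_image = tf.image.random_contrast(input_image, 0.2, 0.5)
        
        # random horizontal and/or vertical flip, applied identically to image and mask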
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_left_right(input_image)
            input_mask = tf.image.flip_left_right(input_mask)
        if tf.random.uniform(()) > 0.5:
            input_image = tf.image.flip_up_down(input_image)
            input_mask = tf.image.flip_up_down(input_mask)

        # rotation in 30° steps (0° to 330°), same angle for image and mask
        rot_factor = tf.cast(tf.random.uniform(shape=[], maxval=12, dtype=tf.int32), tf.float32)
        angle = np.pi / 6 * rot_factor  # pi/6 rad = 30°; the original pi/12 gave 15° steps
        input_image = tfa.image.rotate(input_image, angle)
        input_mask = tfa.image.rotate(input_mask, angle)

    return input_image, input_mask
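
The rotation above depends on tfa.image.rotate from the separate tensorflow_addons package. Where right-angle rotations are enough, core `tf.image.rot90` may be a dependency-free substitute. The following is a minimal sketch under that assumption; `random_rot90_pair` is a hypothetical helper name, not part of the answer's original code.

```python
import tensorflow as tf


def random_rot90_pair(image, mask):
    """Rotate image and mask by the same random multiple of 90 degrees."""
    # Draw k once so both tensors receive the identical rotation.
    k = tf.random.uniform(shape=[], minval=0, maxval=4, dtype=tf.int32)
    return tf.image.rot90(image, k=k), tf.image.rot90(mask, k=k)
```

Because rot90 only rearranges pixels, no interpolation takes place and the mask labels stay intact.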
Tom answered Oct 26 '25 16:10
