How can I combine ImageDataGenerator with TensorFlow datasets in TF2?

I have a TF dataset to classify cats and dogs:

import tensorflow_datasets as tfds
SPLIT_WEIGHTS = (8, 1, 1)
splits = tfds.Split.TRAIN.subsplit(weighted=SPLIT_WEIGHTS)

(raw_train, raw_validation, raw_test), metadata = tfds.load(
    'cats_vs_dogs', split=list(splits),
    with_info=True, as_supervised=True)

In the example they use some image augmentation with a map function. I was wondering if that could also be done with the nice ImageDataGenerator class such as described here:

from tensorflow.keras.preprocessing.image import ImageDataGenerator
train_image_generator = ImageDataGenerator(rescale=1./255) # Generator for our training data
train_data_gen = train_image_generator.flow_from_directory(batch_size=batch_size,
                                                           directory=train_dir,
                                                           shuffle=True,
                                                           target_size=(IMG_HEIGHT, IMG_WIDTH),
                                                           class_mode='binary')

The problem I'm facing is that I can only see 3 ways to use the ImageDataGenerator: pandas dataframe, numpy array and directory of images. Is there any way to also use a Tensorflow dataset and combine these methods?

user2874583 asked Jan 08 '20 15:01

1 Answer

Yes, it is possible, but it's a bit tricky.
Keras ImageDataGenerator works on numpy arrays, not on tf.Tensor objects, so we have to use TensorFlow's tf.numpy_function. This lets us process the contents of a tf.data.Dataset as if they were numpy arrays.
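
As a minimal sketch of that idea, independent of ImageDataGenerator: the snippet below wraps a plain NumPy function with tf.numpy_function and maps it over a toy dataset. The names flip_horizontally, tf_flip and toy_ds, as well as the toy data, are made up for illustration.

```python
import numpy as np
import tensorflow as tf

def flip_horizontally(image):
    # Runs as ordinary Python: `image` arrives as a numpy.ndarray, not a tf.Tensor
    return np.ascontiguousarray(image[:, ::-1, :])

def tf_flip(image, label):
    # Tout must name the dtype that the NumPy function returns
    flipped = tf.numpy_function(flip_horizontally, [image], tf.float32)
    return flipped, label

# Toy data: two 2x3 single-channel "images" with integer labels
images = np.arange(12, dtype=np.float32).reshape(2, 2, 3, 1)
labels = np.array([0, 1], dtype=np.int64)

toy_ds = tf.data.Dataset.from_tensor_slices((images, labels)).map(tf_flip)
first_image, first_label = next(iter(toy_ds))
```

The wrapped function executes in the Python interpreter rather than in the TensorFlow graph, which is what makes NumPy-only code such as ImageDataGenerator usable here.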

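First, let's declare the function that we will map over our dataset (assuming your dataset yields unbatched (image, label) pairs):

import numpy as np
import tensorflow as tf

# We will take 1 original image and create 5 augmented images:
HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create the generator. With only rescale set, no fit() call is needed
  # (fit() is required only for featurewise statistics or ZCA whitening).
  # Add options such as rotation_range or horizontal_flip for real augmentation.
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)]
  label_results = [label]

  # flow() expects a rank-4 batch, so add a batch axis and take the single image back out
  batch = image[np.newaxis, ...]
  augmented_images = [next(img_gen.flow(batch, batch_size=1))[0] for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  # Stack into arrays whose dtypes match the Tout passed to tf.numpy_function
  return np.stack(img_results), np.array(label_results, dtype=np.int32)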

Now, in order to use this function inside a tf.data.Dataset pipeline, we must wrap it in tf.numpy_function:

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

py_augment can now be mapped over the dataset:

augmented_dataset_ds = image_label_dataset.map(py_augment)

The image part of each dataset element now has shape (HOW_MANY_TO_AUGMENT + 1, image_height, image_width, channels): the original image plus its augmented copies. To split this stack back into single images, one (image, label) pair per element, you can simply use unbatch:

unbatched_augmented_dataset_ds = augmented_dataset_ds.unbatch()
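
To see what unbatch does in isolation, here is a small sketch on made-up data: four elements, each holding a stack of 3 tiny "images" and 3 matching labels. The names stacked, flat and shapes are illustrative only.

```python
import tensorflow as tf

# Each element: 3 images of shape (2, 2, 1) and the 3 labels that go with them
stacked = tf.data.Dataset.from_tensor_slices((
    tf.zeros([4, 3, 2, 2, 1]),
    tf.constant([[0, 0, 0], [1, 1, 1], [0, 0, 0], [1, 1, 1]]),
))

# unbatch() splits every component along its first axis
flat = stacked.unbatch()

# Collect the per-element shapes: 4 * 3 = 12 single (image, label) pairs
shapes = [(img.shape.as_list(), lbl.shape.as_list()) for img, lbl in flat]
```

Images and labels are split together, so each image stays paired with its own label.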

So the entire section looks like this:

import numpy as np
import tensorflow as tf

HOW_MANY_TO_AUGMENT = 5

def augment(image, label):

  # Create the generator. With only rescale set, no fit() call is needed
  # (fit() is required only for featurewise statistics or ZCA whitening).
  img_gen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)

  # We want to keep original image and label
  img_results = [(image/255.).astype(np.float32)]
  label_results = [label]

  # flow() expects a rank-4 batch, so add a batch axis and take the single image back out
  batch = image[np.newaxis, ...]
  augmented_images = [next(img_gen.flow(batch, batch_size=1))[0] for _ in range(HOW_MANY_TO_AUGMENT)]
  labels = [label for _ in range(HOW_MANY_TO_AUGMENT)]

  # Append augmented data and labels to original data
  img_results.extend(augmented_images)
  label_results.extend(labels)

  # Stack into arrays whose dtypes match the Tout passed to tf.numpy_function
  return np.stack(img_results), np.array(label_results, dtype=np.int32)

def py_augment(image, label):
  func = tf.numpy_function(augment, [image, label], [tf.float32, tf.int32])
  return func

# image_label_dataset is the original dataset of (image, label) pairs
unbatched_augmented_dataset_ds = image_label_dataset.map(py_augment).unbatch()

# Iterate over the dataset for preview:
for image, label in unbatched_augmented_dataset_ds:
    ...
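
One practical detail when feeding such a pipeline into batching or a model: tensors returned by tf.numpy_function carry no static shape information. A hedged sketch of restoring it with set_shape, assuming all images share one known size (the names halve, tf_halve, toy and batched, and the 4x4x3 size, are hypothetical):

```python
import numpy as np
import tensorflow as tf

IMG_HEIGHT, IMG_WIDTH, CHANNELS = 4, 4, 3  # hypothetical toy size

def halve(image, label):
    # Stand-in for any NumPy-side processing step
    return (image / 2.0).astype(np.float32), label

def tf_halve(image, label):
    image, label = tf.numpy_function(halve, [image, label], [tf.float32, tf.int64])
    # Shapes are unknown after numpy_function, so declare them explicitly
    image.set_shape([IMG_HEIGHT, IMG_WIDTH, CHANNELS])
    label.set_shape([])
    return image, label

toy = tf.data.Dataset.from_tensor_slices((
    np.ones([6, IMG_HEIGHT, IMG_WIDTH, CHANNELS], dtype=np.float32),
    np.arange(6, dtype=np.int64),
))

# With static shapes restored, the batched dataset reports (None, 4, 4, 3) images
batched = toy.map(tf_halve).shuffle(6).batch(2)
```

If the images differ in size, as they may in a real dataset, they would first need resizing to a common shape before set_shape and batch can be applied this way.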

sebastian-sz answered Sep 19 '22 08:09