 

What is the canonical way to split a tf.data.Dataset into test and validation subsets?

Problem

I was following a TensorFlow 2 tutorial on how to load images with pure TensorFlow, because it is supposed to be faster than going through Keras. The tutorial ends before showing how to split the resulting dataset (a tf.data.Dataset) into a training and a validation dataset.

  • I checked the reference for tf.data.Dataset and it does not contain a split() method.

  • I tried slicing it manually, but tf.data.Dataset has neither a size() nor a length() method, so I don't see how I could slice it myself (a take()/skip() workaround is sketched after this list).

  • I can't use the validation_split argument of Model.fit() because I need to augment the training dataset but not the validation dataset.
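
The take()/skip() workaround mentioned above needs an element count, which tf.data does not expose here; a minimal sketch, assuming the count is recovered from the filesystem (the glob call and the 0.8 ratio are illustrative, not from the tutorial):

num_files = len(list(data_dir.glob('*/*')))  # count files on disk; assumes data_dir is a pathlib.Path
train_count = int(0.8 * num_files)           # e.g. an 80-20 split
train_list_ds = list_ds.take(train_count)    # first train_count file paths
val_list_ds = list_ds.skip(train_count)      # the remainder
# Note: list_files() shuffles the file names by default, so pass shuffle=False
# or set a seed to keep this split reproducible across runs.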

Question

What is the intended way to split a tf.data.Dataset, or should I use a different workflow where I won't have to do this?

Example Code

(from the tutorial)

import os
import tensorflow as tf

AUTOTUNE = tf.data.experimental.AUTOTUNE

BATCH_SIZE = 32
IMG_HEIGHT = 224
IMG_WIDTH = 224

# data_dir and CLASS_NAMES are defined earlier in the tutorial.
list_ds = tf.data.Dataset.list_files(str(data_dir/'*/*'))


def get_label(file_path):
  # convert the path to a list of path components
  parts = tf.strings.split(file_path, os.path.sep)
  # The second to last is the class-directory
  return parts[-2] == CLASS_NAMES


def decode_img(img):
  # convert the compressed string to a 3D uint8 tensor
  img = tf.image.decode_jpeg(img, channels=3)
  # Use `convert_image_dtype` to convert to floats in the [0,1] range.
  img = tf.image.convert_image_dtype(img, tf.float32)
  # resize the image to the desired size (resize expects [height, width])
  return tf.image.resize(img, [IMG_HEIGHT, IMG_WIDTH])


def process_path(file_path):
  label = get_label(file_path)
  # load the raw data from the file as a string
  img = tf.io.read_file(file_path)
  img = decode_img(img)
  return img, label


labeled_ds = list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
#...
#...

I can split either list_ds (the dataset of file paths) or labeled_ds (the dataset of image/label pairs), but how?

asked Jan 09 '20 by problemofficer - n.f. Monica


2 Answers

I don't think there's a canonical way (typically, the data is split beforehand, e.g. into separate directories). But here's a recipe that will let you do it dynamically:

# Caveat: cache list_ds, otherwise it will perform the directory listing twice.
ds = list_ds.cache()

# Add some indices.
ds = ds.enumerate()

# Do a roughly 70-30 split.
train_list_ds = ds.filter(lambda i, data: i % 10 < 7)
test_list_ds = ds.filter(lambda i, data: i % 10 >= 7)

# Drop indices.
train_list_ds = train_list_ds.map(lambda i, data: data)
test_list_ds = test_list_ds.map(lambda i, data: data)
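
As a hypothetical continuation (not part of the answer above), the two file lists can then be mapped independently, which addresses the augmentation constraint from the question; augment here is an illustrative stand-in, not from the tutorial:

def augment(image, label):
  # any per-image training-time transform goes here
  image = tf.image.random_flip_left_right(image)
  return image, label

train_ds = train_list_ds.map(process_path, num_parallel_calls=AUTOTUNE).map(augment)
test_ds = test_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)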
answered by Dan Moldovan


Based on Dan Moldovan's answer, I created a reusable function. Maybe it is useful to other people.

def split_dataset(dataset: tf.data.Dataset, validation_data_fraction: float):
    """
    Splits a dataset of type tf.data.Dataset into a training and a validation dataset using the given fraction. The
    fraction is rounded to the nearest percent.
    @param dataset: the input dataset to split.
    @param validation_data_fraction: the fraction of the validation data as a float between 0 and 1.
    @return: a tuple of two tf.data.Datasets as (training, validation)
    """

    validation_data_percent = round(validation_data_fraction * 100)
    if not (0 <= validation_data_percent <= 100):
        raise ValueError("validation data fraction must be ∈ [0,1]")

    dataset = dataset.enumerate()
    # Of every 100 consecutive elements, the first validation_data_percent go to
    # validation and the rest to training (using <= here would put one element
    # too many into the validation set).
    train_dataset = dataset.filter(lambda f, data: f % 100 >= validation_data_percent)
    validation_dataset = dataset.filter(lambda f, data: f % 100 < validation_data_percent)

    # remove enumeration
    train_dataset = train_dataset.map(lambda f, data: data)
    validation_dataset = validation_dataset.map(lambda f, data: data)

    return train_dataset, validation_dataset
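
Example usage with the tutorial's pipeline (the 0.2 fraction is illustrative; cache() avoids repeating the directory listing, per the answer above):

train_list_ds, validation_list_ds = split_dataset(list_ds.cache(), 0.2)
train_ds = train_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)
validation_ds = validation_list_ds.map(process_path, num_parallel_calls=AUTOTUNE)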
answered by problemofficer - n.f. Monica