
How to create a tf.data.Dataset from directories of TFRecord files?

My dataset is split into directories, one per class, and each directory contains a different number of .tfrecords files. How can I sample 5 images from each directory (each .tfrecord file corresponds to one image)? Also, how can I sample 5 of these directories and then sample 5 images from each?

I want to do this using only tf.data.Dataset: a dataset from which I get an iterator, where each call to iterator.next() gives me a batch of 25 images containing 5 samples from each of 5 classes.

Siavash asked Dec 03 '22 20:12



1 Answer

EDIT: If the number of classes is greater than 5, then you can use the new tf.contrib.data.sample_from_datasets() API (currently available in tf-nightly and will be available in TensorFlow 1.9).

import tensorflow as tf  # TensorFlow 1.x; tf.contrib was removed in 2.0

directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", ...]

CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5
BATCH_SIZE = CLASSES_PER_BATCH * EXAMPLES_PER_CLASS_PER_BATCH
NUM_CLASSES = len(directories)


# Build one dataset per class.
per_class_datasets = [
    tf.data.TFRecordDataset(tf.data.Dataset.list_files(d)) for d in directories]

# Next, build a dataset where each element is a vector of 5 classes to be chosen
# for a particular batch.
classes_per_batch_dataset = tf.contrib.data.Counter().map(
    lambda _: tf.random_shuffle(tf.range(NUM_CLASSES))[:CLASSES_PER_BATCH])

# Transform the dataset of per-batch class vectors into a dataset with one
# one-hot element per example (i.e. 25 examples per batch).
class_dataset = classes_per_batch_dataset.flat_map(
    lambda classes: tf.data.Dataset.from_tensor_slices(
        tf.one_hot(classes, NUM_CLASSES)).repeat(EXAMPLES_PER_CLASS_PER_BATCH))

# Use `tf.contrib.data.sample_from_datasets()` to select an example from the
# appropriate dataset in `per_class_datasets`.
example_dataset = tf.contrib.data.sample_from_datasets(per_class_datasets,
                                 class_dataset)

# Finally, combine 25 consecutive examples into a batch.
result = example_dataset.batch(BATCH_SIZE)
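
To see what the selector stream looks like without running TensorFlow, here is a minimal plain-Python sketch of the same selection logic. It only models which class each of the 25 slots in a batch is drawn from; it does not read any records. The value `NUM_CLASSES = 8` and the helper `class_selector` are illustrative assumptions, not part of the code above.

```python
import random
from collections import Counter

NUM_CLASSES = 8                    # hypothetical; the answer uses len(directories)
CLASSES_PER_BATCH = 5
EXAMPLES_PER_CLASS_PER_BATCH = 5


def class_selector(num_batches, rng):
    """Yield one class index per example, mirroring `class_dataset`."""
    for _ in range(num_batches):
        # Pick 5 distinct classes, like shuffling range(NUM_CLASSES)
        # and keeping the first CLASSES_PER_BATCH entries.
        chosen = rng.sample(range(NUM_CLASSES), CLASSES_PER_BATCH)
        # Repeating the 5-element sequence 5 times gives 25 selections.
        for _ in range(EXAMPLES_PER_CLASS_PER_BATCH):
            yield from chosen


rng = random.Random(0)
batch = list(class_selector(1, rng))
counts = Counter(batch)
print(len(batch), sorted(counts.values()))
```

Each batch therefore holds 25 selections spread evenly over 5 distinct classes. In the TensorFlow version, each selection is encoded as a one-hot weight vector, so `sample_from_datasets()` draws the next record from exactly that class's dataset.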

If you have exactly 5 classes, you can define a nested dataset for each directory and combine them using Dataset.interleave():

# NOTE: We're assuming that the 0th directory contains elements from class 0, etc.
directories = ["class_0/*", "class_1/*", "class_2/*", "class_3/*", "class_4/*"]
directories = tf.data.Dataset.from_tensor_slices(directories)
directories = directories.apply(tf.contrib.data.enumerate_dataset())    

# Define a function that maps each (class, directory) pair to the (shuffled)
# records in those files.
def per_directory_dataset(class_label, directory_glob):
  files = tf.data.Dataset.list_files(directory_glob, shuffle=True)
  records = tf.data.TFRecordDataset(files)
  # Zip the records with their class. 
  # NOTE: This part might not be necessary if the records contain information about
  # their class that can be parsed from them.
  return tf.data.Dataset.zip(
      (records, tf.data.Dataset.from_tensors(class_label).repeat(None)))

# NOTE: The `cycle_length` and `block_length` here aren't strictly necessary,
# because the batch size is exactly `number of classes * images per class`.
# However, these arguments may be useful if you want to decouple these numbers.
merged_records = directories.interleave(per_directory_dataset,
                                        cycle_length=5, block_length=5)
merged_records = merged_records.batch(25)
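
To make the effect of `cycle_length` and `block_length` concrete, here is a simplified plain-Python model of the interleaving order. It is a sketch under assumptions: the `interleave` function below is a hypothetical stand-in that only handles the case where the number of sources does not exceed `cycle_length`, and the record names are made up.

```python
from itertools import islice


def interleave(sources, cycle_length, block_length):
    """Simplified model: round-robin over sources, block_length items at a time."""
    iterators = [iter(s) for s in sources[:cycle_length]]
    while iterators:
        for it in list(iterators):
            block = list(islice(it, block_length))
            yield from block
            if len(block) < block_length:
                iterators.remove(it)  # this source is exhausted


# Hypothetical (record, label) pairs: 12 records in each of 5 classes.
sources = [[("class_%d/img_%d" % (c, i), c) for i in range(12)]
           for c in range(5)]

batch = list(islice(interleave(sources, cycle_length=5, block_length=5), 25))
labels = [label for _, label in batch]
print(labels)
```

With 5 sources and a block length of 5, the first 25 elements contain 5 consecutive records from each class, which is why batching by 25 lines up with the class blocks. Note that in this model, once a directory runs out of records, later batches are no longer balanced; calling `.repeat()` on each per-directory dataset is one way to avoid that, if repeating records is acceptable.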
mrry answered Jan 08 '23 04:01
