TF data API: how to efficiently sample small patches from images

Consider the problem of creating a dataset by sampling random small image patches from a directory of high-resolution images. The TensorFlow Dataset API allows for a very easy way of doing this: construct a dataset of image filenames, shuffle them, map them to loaded images, then map those to random cropped patches.
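
For concreteness, here is a rough sketch of what I mean by the naive pipeline (the filename list, patch size, and JPEG decoding are placeholders I made up, not part of the actual setup):

import tensorflow as tf

filenames = ["img_0001.jpg", "img_0002.jpg"]  # hypothetical list of high-res images
patch_size = 32                               # hypothetical patch size

def load_and_crop(filename):
    """Load and decode one image, then return a single random crop from it."""
    image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    return tf.image.random_crop(image, [patch_size, patch_size, 3])

dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=len(filenames))
    .map(load_and_crop)   # a full image is loaded and decoded for every single patch
    .batch(10))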

However, this naive implementation is very inefficient: a separate high-resolution image is loaded and cropped to generate each individual patch. Ideally, an image would be loaded once and reused to generate many patches.

One simple way that was discussed previously is to generate multiple patches from an image and flatten them. However, this has the unfortunate effect of biasing the data too much: we want each training batch to contain patches from different images.
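
Roughly, that flattening approach looks like this (reusing the placeholder filenames and patch_size from the sketch above; the helper name and crop count are just illustrative):

def many_crops(filename, num_patches=16):
    """Decode one image and return a small dataset of random crops from it."""
    image = tf.image.decode_jpeg(tf.read_file(filename), channels=3)
    crops = [tf.image.random_crop(image, [patch_size, patch_size, 3])
             for _ in range(num_patches)]
    return tf.data.Dataset.from_tensor_slices(tf.stack(crops))

dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=len(filenames))
    .flat_map(many_crops)   # consecutive elements come from the same image,
    .batch(10))             # so each batch is dominated by a single image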

Ideally what I would like is a "random caching filter" transformation that takes an underlying dataset and caches N elements of it into memory. Its iterator will return a random element from the cache. Also, with pre-defined frequency it will replace a random element from the cache with a new one from the underlying dataset. This filter will allow for faster data access at the expense of less randomization and higher memory consumption.
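
To make the desired semantics concrete, here is a rough pure-Python sketch of such a filter over any iterable (the name, cache size, and replacement probability are made up; it would presumably need to be wrapped, e.g. via tf.data.Dataset.from_generator, to feed an input pipeline):

import random

def random_caching_filter(source, cache_size=50, replace_prob=0.1):
    """Yield random elements from a cache of `cache_size` items drawn from `source`.

    With probability `replace_prob` per yielded element, a random cache slot
    is refreshed with the next element of the underlying iterable.
    """
    it = iter(source)
    cache = [next(it) for _ in range(cache_size)]  # fill the cache up front
    while True:
        yield random.choice(cache)
        if random.random() < replace_prob:
            try:
                cache[random.randrange(cache_size)] = next(it)
            except StopIteration:
                return  # underlying dataset exhausted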

Is there such functionality available?

If not, should it be implemented as a new dataset transformation or simply a new iterator? It seems a new iterator is all that is needed. Any pointers on how to create a new dataset iterator, ideally in C++?

Asked Feb 14 '18 by Dimofeevich


1 Answer

You should be able to use tf.data.Dataset.shuffle to achieve what you want. Here is a quick summary of the objectives:

  • load very big images, produce smaller random crops from the images and batch them together
  • make the pipeline efficient by creating multiple patches from a big image once the image is loaded
  • add enough shuffle so that a batch of patches is diverse (all the patches come from different images)
  • don't load too many big images in cache

You can achieve all that using the tf.data API by doing the following steps:

  1. shuffle the filenames of the big images
  2. read the big images
  3. generate multiple patches from each image
  4. shuffle again all these patches with a big enough buffer size (see this answer on buffer size). Adjusting the buffer size is a tradeoff between good shuffling and size of the cached patches
  5. batch them
  6. prefetch one batch

Here is the relevant code:

import tensorflow as tf

filenames = ...  # filenames containing the big images
num_samples = len(filenames)

# Parameters
num_patches = 100               # number of patches to extract from each image
patch_size = 32                 # size of the patches
buffer_size = 50 * num_patches  # shuffle patches from 50 different big images
num_parallel_calls = 4          # number of threads
batch_size = 10                 # size of the batch

get_patches_fn = lambda image: get_patches(image, num_patches=num_patches, patch_size=patch_size)

# Create a Dataset serving batches of random patches in our images
dataset = (tf.data.Dataset.from_tensor_slices(filenames)
    .shuffle(buffer_size=num_samples)  # step 1: putting all the filenames in the buffer ensures good shuffling
    .map(parse_fn, num_parallel_calls=num_parallel_calls)  # step 2
    .map(get_patches_fn, num_parallel_calls=num_parallel_calls)  # step 3
    .apply(tf.contrib.data.unbatch())  # unbatch the patches we just produced
    .shuffle(buffer_size=buffer_size)  # step 4
    .batch(batch_size)  # step 5
    .prefetch(1)  # step 6: make sure you always have one batch ready to serve
)

iterator = dataset.make_one_shot_iterator()
patches = iterator.get_next()  # shape [None, patch_size, patch_size, 3]


sess = tf.Session()
res = sess.run(patches)

The functions parse_fn and get_patches are defined like this:

def parse_fn(filename):
    """Decode the jpeg image from the filename and convert to [0, 1]."""
    image_string = tf.read_file(filename)

    # Don't use tf.image.decode_image, or the output shape will be undefined
    image_decoded = tf.image.decode_jpeg(image_string, channels=3)

    # This will convert to float values in [0, 1]
    image = tf.image.convert_image_dtype(image_decoded, tf.float32)

    return image


def get_patches(image, num_patches=100, patch_size=16):
    """Get `num_patches` random crops from the image"""
    patches = []
    for i in range(num_patches):
        patch = tf.image.random_crop(image, [patch_size, patch_size, 3])
        patches.append(patch)

    patches = tf.stack(patches)
    assert patches.get_shape().dims == [num_patches, patch_size, patch_size, 3]

    return patches
Answered Nov 04 '22 by Olivier Moindrot