
parallelising tf.data.Dataset.from_generator

I have a non-trivial input pipeline that from_generator is perfect for:

```python
import tensorflow as tf

# complex_img_label_generator yields (image, label) pairs
dataset = tf.data.Dataset.from_generator(complex_img_label_generator,
                                         (tf.int32, tf.string))
dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
```

Here complex_img_label_generator dynamically generates images and yields a NumPy array representing an (H, W, 3) image together with a simple string label. The processing is not something I can express as reading from files plus tf.image operations.

My question is how to parallelise the generator. How do I get N of these generators running in their own threads?

One thought was to use dataset.map with num_parallel_calls to handle the threading, but map operates on tensors... Another thought was to create multiple generators, each with its own prefetch, and somehow join them, but I can't see how I'd join N generator streams.

Any canonical examples I could follow?
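The second idea, N generator instances each running in its own thread and merged into one stream, can be sketched with only the standard library. This is an illustration under assumptions, not a tf.data API: parallel_generator, make_gen and toy_gen are hypothetical names, items come out in arrival order rather than round-robin, and because of Python's GIL it only speeds things up when the per-item work releases the GIL (some NumPy and PIL operations do). Since from_generator takes a callable that returns an iterator, the merged stream could be passed as something like lambda: parallel_generator(make_gen, 4).

```python
import queue
import threading


def parallel_generator(make_gen, num_workers, max_buffered=64):
    """Hypothetical helper: run num_workers generators in background threads
    and yield their items from one shared, bounded queue (arrival order).

    make_gen(worker_id) must return a fresh generator for each worker.
    """
    q = queue.Queue(maxsize=max_buffered)
    done = object()  # sentinel a worker enqueues when its generator is exhausted

    def worker(worker_id):
        try:
            for item in make_gen(worker_id):
                q.put(item)  # blocks while the buffer is full
        finally:
            # Errors inside a worker are only printed by the thread, not re-raised here.
            q.put(done)

    threads = [threading.Thread(target=worker, args=(i,), daemon=True)
               for i in range(num_workers)]
    for t in threads:
        t.start()

    finished = 0
    while finished < num_workers:
        item = q.get()
        if item is done:
            finished += 1
        else:
            yield item


def toy_gen(worker_id):
    # Hypothetical stand-in for complex_img_label_generator.
    for i in range(3):
        yield (worker_id, i)


items = list(parallel_generator(toy_gen, num_workers=4))
```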

mat kelcey asked Nov 03 '17 00:11



2 Answers

It turns out I can use Dataset.map if I make the generator super lightweight (only generating metadata) and move the actual heavy lifting into a stateless function. This way I can parallelise just the heavy-lifting part with .map, using a py_func.

It works, but it feels a tad clumsy... It would be great to be able to just add num_parallel_calls to from_generator :)

```python
import tensorflow as tf


def pure_numpy_and_pil_complex_calculation(metadata, label):
    # some complex PIL and NumPy work, nothing to do with tf
    ...


# The generator only yields cheap metadata; no heavy work happens here.
dataset = tf.data.Dataset.from_generator(lightweight_generator,
                                         output_types=(tf.string,   # metadata
                                                       tf.string))  # label


def wrapped_complex_calculation(metadata, label):
    # Wrap the pure Python function so map() can call it on tensors.
    return tf.py_func(func=pure_numpy_and_pil_complex_calculation,
                      inp=[metadata, label],
                      Tout=(tf.uint8,    # (H, W, 3) img
                            tf.string))  # label


# The heavy lifting runs in parallel across 8 calls.
dataset = dataset.map(wrapped_complex_calculation,
                      num_parallel_calls=8)

dataset = dataset.batch(64)
iter = dataset.make_one_shot_iterator()
imgs, labels = iter.get_next()
```
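The same split can be sketched in plain Python to show why it parallelises: the generator stays cheap and sequential, while the expensive step is a pure function that any worker can run. Here concurrent.futures.ThreadPoolExecutor stands in for Dataset.map with num_parallel_calls, and lightweight_generator and heavy_stateless are hypothetical toy versions of the functions in the answer. Like Dataset.map, which by default keeps elements in input order, Executor.map returns results in input order.

```python
from concurrent.futures import ThreadPoolExecutor


def lightweight_generator(n):
    # Hypothetical: yields only cheap metadata (an integer "size" and a label).
    for i in range(n):
        yield i, "label_%d" % i


def heavy_stateless(metadata, label):
    # Hypothetical stand-in for the expensive pure NumPy/PIL work.
    # It depends only on its arguments, so calls can run concurrently.
    return sum(j * j for j in range(metadata * 1000)), label


with ThreadPoolExecutor(max_workers=8) as pool:
    # Executor.map yields results in input order, regardless of which thread finishes first.
    results = list(pool.map(lambda ml: heavy_stateless(*ml), lightweight_generator(5)))
```

Unlike Dataset.map, Executor.map submits every input up front, so this toy version only suits small, finite inputs.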
mat kelcey answered Sep 29 '22 16:09


I am working on a from_indexable for tf.data.Dataset: https://github.com/tensorflow/tensorflow/issues/14448

The advantage of from_indexable is that it can be parallelized, while a Python generator cannot.

The function from_indexable creates a tf.data.Dataset.range, wraps the indexable in a generalized tf.py_func, and calls map.
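The reasoning above can be mimicked without TensorFlow: an indexable supports independent random access, so several workers can each fetch a different index with no shared iterator state, which is what a Python generator lacks (it must be advanced one next() call at a time). In this sketch, ThreadPoolExecutor.map over range(len(...)) plays the role of Dataset.range(...).map(..., num_parallel_calls=...); SquaresDataSet and from_indexable_sketch are hypothetical names.

```python
from concurrent.futures import ThreadPoolExecutor


class SquaresDataSet:
    # Hypothetical indexable: random access by index, no shared iteration state.
    def __len__(self):
        return 6

    def __getitem__(self, index):
        return index * index


def from_indexable_sketch(indexable, num_workers):
    # Analogue of Dataset.range(len(...)).map(index_to_entry, num_parallel_calls=...):
    # each index is fetched by an independent call, so there is no shared iterator to serialize on.
    with ThreadPoolExecutor(max_workers=num_workers) as pool:
        return list(pool.map(indexable.__getitem__, range(len(indexable))))


serial = from_indexable_sketch(SquaresDataSet(), num_workers=1)
parallel = from_indexable_sketch(SquaresDataSet(), num_workers=4)
```

Because each element depends only on its index, the result is the same for one worker or four.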

For those who want a from_indexable now, here is the library code:

```python
import tensorflow as tf
import numpy as np

from tensorflow.python.framework import tensor_shape
from tensorflow.python.util import nest


def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args):
            nonlocal output_shapes

            flat_output_types = nest.flatten(output_types)
            flat_values = tf.py_func(
                func,
                inp=args,
                Tout=flat_output_types,
                stateful=stateful, name=name
            )
            if output_shapes is not None:
                # I am not sure if this is necessary
                output_shapes = nest.map_structure_up_to(
                    output_types, tensor_shape.as_shape, output_shapes)
                flattened_shapes = nest.flatten_up_to(output_types, output_shapes)
                for ret_t, shape in zip(flat_values, flattened_shapes):
                    ret_t.set_shape(shape)
            return nest.pack_sequence_as(output_types, flat_values)
        return call
    return decorator


def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    # `iterator` is any indexable object (supports len() and [])
    ds = tf.data.Dataset.range(len(iterator))

    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]

    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
```

And here is an example (note: from_indexable has a num_parallel_calls argument):

```python
class PyDataSet:
    def __len__(self):
        return 20

    def __getitem__(self, item):
        return np.random.normal(size=(item+1, 10))


ds = from_indexable(PyDataSet(), output_types=tf.float64, output_shapes=[None, 10])
it = ds.make_one_shot_iterator()
entry = it.get_next()
with tf.Session() as sess:
    print(sess.run(entry).shape)
    print(sess.run(entry).shape)
```

Update June 10, 2018: Now that https://github.com/tensorflow/tensorflow/pull/15121 has been merged, the code for from_indexable simplifies to:

```python
import tensorflow as tf


def py_func_decorator(output_types=None, output_shapes=None, stateful=True, name=None):
    def decorator(func):
        def call(*args, **kwargs):
            return tf.contrib.framework.py_func(
                func=func,
                args=args, kwargs=kwargs,
                output_types=output_types, output_shapes=output_shapes,
                stateful=stateful, name=name
            )
        return call
    return decorator


def from_indexable(iterator, output_types, output_shapes=None, num_parallel_calls=None, stateful=True, name=None):
    # `iterator` is any indexable object (supports len() and [])
    ds = tf.data.Dataset.range(len(iterator))

    @py_func_decorator(output_types, output_shapes, stateful=stateful, name=name)
    def index_to_entry(index):
        return iterator[index]

    return ds.map(index_to_entry, num_parallel_calls=num_parallel_calls)
```
Christoph Böddeker answered Sep 29 '22 17:09
