
Tensorflow Dataset .map() API

A couple of questions about this:

For occasions when I'd like to do something like the following in Tensorflow (assume I'm creating training examples by loading WAV files):

import tensorflow as tf
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio
from tensorflow.python.ops import io_ops

def _some_audio_preprocessing_func(filename):
    # ... some logic here which mostly uses Tensorflow ops ...
    with tf.Session(graph=tf.Graph()) as sess:
        wav_filename_placeholder = tf.placeholder(tf.string, [])
        wav_loader = io_ops.read_file(wav_filename_placeholder)
        wav_decoder = contrib_audio.decode_wav(wav_loader, desired_channels=1)
        data = sess.run(
            [wav_decoder],
            feed_dict={wav_filename_placeholder: filename})
        return data

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_audio_preprocessing_func)
  1. If I have a parse_image() function that uses tensor ops - should this be part of the main Graph? Following the example set in Google's own audio TF tutorial, it looks like they create a separate graph! Doesn't this ruin the point of using Tensorflow to make things faster?
  2. Do I use tf.py_func() any time any single line isn't from the tensorflow library? Again, I wonder what the performance implications are and when I should use this...

Thanks!

lollercoaster asked Mar 14 '18 05:03


1 Answer

When you use Dataset.map(map_func), TensorFlow defines a subgraph for all the ops created in the function map_func, and arranges to execute it efficiently in the same session as the rest of your graph. There is almost never any need to create a tf.Graph or tf.Session inside map_func: if your parsing function is made up of TensorFlow ops, these ops can be embedded directly in the graph that defines the input pipeline.

The modified version of the code using tf.data would look like this:

import tensorflow as tf 
from tensorflow.contrib.framework.python.ops import audio_ops as contrib_audio

def _some_audio_preprocessing_func(filename):
    wav_loader = tf.read_file(filename)
    return contrib_audio.decode_wav(wav_loader, desired_channels=1)

dataset = tf.data.Dataset.list_files('*.wav')
dataset = dataset.map(_some_audio_preprocessing_func)
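For readers on TensorFlow 2.x, where tf.contrib was removed and tf.read_file became tf.io.read_file, roughly the same pipeline can be sketched as follows. This is only an illustration under that version assumption: the temporary directory, the file name tone.wav, the silent test clip, and the load_wav name are all made up so the snippet can run on its own.

```python
import os
import tempfile

import tensorflow as tf

# Hypothetical test fixture: write 0.1 s of silence (16 kHz, mono) to a temp file
# so that list_files() has something to match.
tmp_dir = tempfile.mkdtemp()
wav_path = os.path.join(tmp_dir, "tone.wav")
silence = tf.zeros([1600, 1], dtype=tf.float32)
tf.io.write_file(wav_path, tf.audio.encode_wav(silence, sample_rate=16000))

def load_wav(filename):
    # Only TensorFlow ops here, so no tf.Graph, tf.Session, or py_func is needed;
    # Dataset.map() traces these ops into the input pipeline's own graph.
    contents = tf.io.read_file(filename)
    audio, sample_rate = tf.audio.decode_wav(contents, desired_channels=1)
    return audio, sample_rate

dataset = tf.data.Dataset.list_files(os.path.join(tmp_dir, "*.wav"))
dataset = dataset.map(load_wav)

# Pull one element to confirm the pipeline decodes the file.
audio, sample_rate = next(iter(dataset))
```

The point is the same as in the TF 1.x version above: the function passed to map() just builds ops on its filename argument and returns the resulting tensors.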

If your map_func contains non-TensorFlow operations that you want to apply to each element, you should wrap them in a tf.py_func() (or Dataset.from_generator(), if the data generation process is defined in Python logic). The main performance implication is that any code running in a tf.py_func() is subject to the Global Interpreter Lock, so I would generally recommend trying to find a native TensorFlow implementation for anything that is performance critical.
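As a rough sketch of wrapping a Python-only step, the snippet below uses tf.numpy_function, which newer releases (TensorFlow 1.14+ and 2.x) provide in place of tf.py_func; on the TF 1.x versions this answer was written for, tf.py_func plays the same role. The _peak_normalize and _wrapped_normalize names and the tiny in-memory dataset are hypothetical stand-ins, not part of the original code.

```python
import numpy as np
import tensorflow as tf

def _peak_normalize(samples):
    # Plain NumPy code standing in for any non-TensorFlow logic.
    # It runs in the Python interpreter, so it is subject to the GIL.
    peak = np.max(np.abs(samples))
    if peak > 0:
        samples = samples / peak
    return samples.astype(np.float32)

def _wrapped_normalize(samples):
    # Wrap the Python function so it can be called from inside Dataset.map().
    out = tf.numpy_function(_peak_normalize, [samples], tf.float32)
    # The wrapper loses static shape information, so restore it by hand.
    out.set_shape(samples.shape)
    return out

# Hypothetical in-memory data in place of decoded audio.
raw = np.array([[1.0, -2.0, 4.0], [0.5, 0.25, -1.0]], dtype=np.float32)
dataset = tf.data.Dataset.from_tensor_slices(raw).map(_wrapped_normalize)

normalized = [element.numpy() for element in dataset]
```

If a step like this turns out to be a bottleneck, rewriting it with native ops (here, tf.reduce_max and tf.abs would do) keeps it out of the Python interpreter entirely.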

mrry answered Sep 19 '22 01:09