 

Adding TensorBoard summaries from graph ops generated inside Dataset map() function calls

I've found the Dataset.map() functionality pretty nice for setting up pipelines that preprocess image/audio data before feeding it into the network for training, but one issue I have is accessing the raw data before preprocessing so I can send it to TensorBoard as a summary.

For example, say I have a function that loads audio data, does some framing, makes a spectrogram, and returns this.

import tensorflow as tf 

def load_audio_examples(label, path):
    # loads audio, converts to spectrogram
    pcm = ...  # this is what I'd like to put into tf.summary.audio()!
    # creates one-hot encoded labels, etc
    return labels, examples

# create dataset
training = tf.data.Dataset.from_tensor_slices((
    tf.constant(labels), 
    tf.constant(paths)
))

training = training.map(load_audio_examples, num_parallel_calls=4)

# create ops for training
train_step = # ...
accuracy = # ...

# create iterator
iterator = training.repeat().make_one_shot_iterator()
next_element = iterator.get_next()

# ready session
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
train_writer = # ...

# test iterator
test_iterator = testing.make_one_shot_iterator()
test_next_element = test_iterator.get_next()

# train loop
for i in range(100):
    batch_ys, batch_xs, path = sess.run(next_element)
    summary, train_acc, _ = sess.run([summaries, accuracy, train_step], 
        feed_dict={x: batch_xs, y: batch_ys})
    train_writer.add_summary(summary, i) 

It appears as though this does not become part of the graph that is plotted in the "Graph" tab of TensorBoard (see screenshot below).

(screenshot: TensorBoard Graph tab)

As you can see, it's just X (the output of the preprocessing map() function).

  1. How would I better structure this to get the raw audio into a tf.summary.audio()? Right now the things inside map() aren't accessible as Tensors inside my training loop.
  2. Also, why isn't my graph showing up in TensorBoard? It worries me that I won't be able to export my model or use TensorFlow Serving to put my model into production because I'm using the new Dataset API. Maybe I should go back to doing things manually (with queues, etc.)?
asked Mar 26 '18 by lollercoaster


1 Answer

I think your use of the Dataset API doesn't make much sense. In fact, you have two disconnected subgraphs: one for reading data and the other for running your training step.

batch_ys, batch_xs, path = sess.run(next_element)
summary, train_acc, _ = sess.run([summaries, accuracy, train_step], 
    feed_dict={x: batch_xs, y: batch_ys})

The first line in the code above runs the session and fetches a batch from the iterator, transferring data from the TensorFlow backend into Python.

The second line then feeds that same data back in through feed_dict, which is known to be inefficient: this time TensorFlow transfers the data from Python back into the runtime.

This has the following consequences:

  1. Your graph looks disconnected.
  2. TensorFlow wastes time on unnecessary data transfers to and from Python.

To have a single graph (without disconnected subgraphs), you need to build your model on top of the tensors returned by the Dataset API. Note that it is also possible to switch between training and testing datasets without manually fetching batches (see the Dataset guide); a sketch of this pattern follows.
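Here is a minimal sketch of that pattern, reusing load_audio_examples() from the question. build_model() and the train_labels/train_paths and test_labels/test_paths tensors are hypothetical stand-ins for your own code. It uses a reinitializable iterator so the same model graph can read from either dataset:

import tensorflow as tf

# Both datasets go through the same map() preprocessing.
training = (tf.data.Dataset.from_tensor_slices((train_labels, train_paths))
            .map(load_audio_examples, num_parallel_calls=4)
            .repeat())
testing = (tf.data.Dataset.from_tensor_slices((test_labels, test_paths))
           .map(load_audio_examples, num_parallel_calls=4))

# A reinitializable iterator lets you switch datasets without feed_dict.
iterator = tf.data.Iterator.from_structure(training.output_types,
                                           training.output_shapes)
labels, examples = iterator.get_next()
train_init_op = iterator.make_initializer(training)
test_init_op = iterator.make_initializer(testing)

# The model consumes the iterator's tensors directly, so the whole
# pipeline lives in one connected graph.
logits = build_model(examples)
loss = tf.losses.softmax_cross_entropy(onehot_labels=labels, logits=logits)
train_step = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(train_init_op)
    for i in range(100):
        sess.run(train_step)   # no feed_dict, no round trip through Python
    sess.run(test_init_op)     # switch to the test set

Switching between training and evaluation then comes down to running the corresponding initializer op, and the data never leaves the TensorFlow runtime.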

As for the summary defined inside the map() function, I believe you can retrieve it from the SUMMARIES collection (the default collection for summaries). You can also pass your own collection name when adding the summary operation; a rough sketch follows.
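A rough sketch of the collection mechanism itself, using a stand-in pcm tensor and a made-up collection name 'audio_summaries'. Whether a summary op created inside map() actually ends up in the outer graph's collection depends on your TensorFlow version, so treat that part as an assumption:

import tensorflow as tf

# Stand-in for the raw audio you want to log; in your case it would be
# the pcm tensor inside load_audio_examples(). Shape: [batch, frames, channels].
pcm = tf.zeros([1, 16000, 1])
sample_rate = 16000

# Register the summary under a custom collection; including
# tf.GraphKeys.SUMMARIES keeps it visible to merge_all() as well.
tf.summary.audio('raw_audio', pcm, sample_rate,
                 collections=['audio_summaries', tf.GraphKeys.SUMMARIES])

# Merge everything registered under the custom collection...
audio_summaries = tf.summary.merge(tf.get_collection('audio_summaries'))
# ...or merge the default SUMMARIES collection.
all_summaries = tf.summary.merge_all()

with tf.Session() as sess:
    writer = tf.summary.FileWriter('/tmp/logs', sess.graph)
    writer.add_summary(sess.run(audio_summaries), global_step=0)

tf.summary.merge_all() picks up everything in tf.GraphKeys.SUMMARIES, while tf.summary.merge() lets you build a merged op from any collection you choose.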

answered by dm0_