I've found the Dataset.map() functionality pretty nice for setting up pipelines that preprocess image/audio data before feeding it into the network for training, but one issue I have is accessing the raw data before the preprocessing so I can send it to TensorBoard as a summary.
For example, say I have a function that loads audio data, does some framing, makes a spectrogram, and returns this.
import tensorflow as tf

def load_audio_examples(label, path):
    # loads audio, converts to spectrogram
    pcm = ...  # this is what I'd like to put into tf.summary.audio()!
    # creates one-hot encoded labels, etc.
    return labels, examples

# create dataset
training = tf.data.Dataset.from_tensor_slices((
    tf.constant(labels),
    tf.constant(paths)
))
training = training.map(load_audio_examples, num_parallel_calls=4)

# create ops for training
train_step = # ...
accuracy = # ...

# create iterator
iterator = training.repeat().make_one_shot_iterator()
next_element = iterator.get_next()

# ready session
sess = tf.InteractiveSession()
tf.global_variables_initializer().run()
train_writer = # ...

# test iterator
test_iterator = testing.make_one_shot_iterator()
test_next_element = test_iterator.get_next()

# train loop
for i in range(100):
    batch_ys, batch_xs, path = sess.run(next_element)
    summary, train_acc, _ = sess.run([summaries, accuracy, train_step],
                                     feed_dict={x: batch_xs, y: batch_ys})
    train_writer.add_summary(summary, i)
It appears as though this does not become part of the graph that is plotted in the "Graph" tab of TensorBoard (see screenshot below).
As you can see, it's just X (the output of the preprocessing map() function).
Is there a way I could still use ops like tf.summary.audio() on the raw audio? Right now the things inside map() aren't accessible as Tensors inside my training loop.

I think your use of the Dataset API doesn't make much sense. In fact, you have two disconnected subgraphs: one for reading data and the other for running your training step.
batch_ys, batch_xs, path = sess.run(next_element)
summary, train_acc, _ = sess.run([summaries, accuracy, train_step],
                                 feed_dict={x: batch_xs, y: batch_ys})
The first line runs the session and fetches data items from it, transferring the data from the TensorFlow runtime into Python. The second line then feeds that data back in through feed_dict, which is known to be inefficient: this time TensorFlow transfers the data from Python back into the runtime.
This has the following consequences: the graph is split into disconnected pieces (which is why the preprocessing doesn't show up attached to your model in TensorBoard), and every batch makes a needless round trip between the runtime and Python. To have a single graph (without disconnected subgraphs), you need to build your model on top of the tensors returned by the Dataset API. Note that it is then possible to switch between training and testing datasets without manually fetching batches (see the Dataset guide); a minimal sketch follows.
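Here is one way that single-graph approach could look, using a reinitializable iterator to switch datasets. Treat it as a sketch: labels/paths, test_labels/test_paths, and my_model() are placeholders for your own data and network, not part of any API.

import tensorflow as tf

# Assumed inputs: labels/paths and test_labels/test_paths are lists or
# numpy arrays you already have; my_model() is your own network.
training = tf.data.Dataset.from_tensor_slices((tf.constant(labels),
                                               tf.constant(paths)))
training = training.map(load_audio_examples, num_parallel_calls=4).repeat()

testing = tf.data.Dataset.from_tensor_slices((tf.constant(test_labels),
                                              tf.constant(test_paths)))
testing = testing.map(load_audio_examples, num_parallel_calls=4)

# A reinitializable iterator accepts any dataset with a matching structure,
# so the same graph serves both training and testing.
iterator = tf.data.Iterator.from_structure(training.output_types,
                                           training.output_shapes)
next_labels, next_examples = iterator.get_next()
training_init_op = iterator.make_initializer(training)
testing_init_op = iterator.make_initializer(testing)

# Build the model directly on the iterator's tensors: no placeholders,
# no feed_dict, one connected graph.
logits = my_model(next_examples)
loss = tf.losses.softmax_cross_entropy(next_labels, logits)
train_step = tf.train.AdamOptimizer().minimize(loss)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(training_init_op)
    for i in range(100):
        sess.run(train_step)      # data never leaves the TensorFlow runtime
    sess.run(testing_init_op)     # switch to the test set in place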
As for a summary defined inside the map function: I believe you can retrieve it from the SUMMARIES collection (the default collection for summary ops). You can also pass your own collection name when adding the summary operation, as in the sketch below.
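A hedged sketch of that, assuming pcm is a batched waveform tensor (e.g. shape [batch, frames] or [batch, frames, channels]) and a 16 kHz sample rate; both are stand-ins for whatever your pipeline actually produces.

# 'pcm' and the 16000 Hz sample rate are assumptions for illustration.
# The op is registered in the default SUMMARIES collection and in a
# custom 'audio_summaries' collection.
audio_summary = tf.summary.audio('raw_audio', pcm, sample_rate=16000,
                                 max_outputs=3,
                                 collections=[tf.GraphKeys.SUMMARIES,
                                              'audio_summaries'])

# Merge everything registered in the default collection...
summaries = tf.summary.merge_all()

# ...or merge only the ops from the custom collection:
audio_only = tf.summary.merge(tf.get_collection('audio_summaries'))

# Either merged op is then fetched and written as usual:
# summary_val = sess.run(summaries)
# train_writer.add_summary(summary_val, i)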