
Keras: use Tensorboard with train_on_batch()


For the Keras functions fit() and fit_generator() you can get TensorBoard visualization by passing a keras.callbacks.TensorBoard object to the functions. The train_on_batch() function, however, takes no callbacks. Are there other options in Keras to create a TensorBoard log in this case?

Oliver Wilken asked Jul 01 '17 12:07



2 Answers

One possible way is to create the TensorBoard callback and drive it manually:

    # This example shows how to use the Keras TensorBoard callback
    # with model.train_on_batch()
    import tensorflow.keras as keras

    # Set up the model
    model = keras.models.Sequential()
    model.add(...)  # Add your layers
    model.compile(...)  # Compile as usual

    batch_size = 256

    # Create the TensorBoard callback, which we will drive manually.
    # The TF 1.x arguments batch_size and write_grads are left out:
    # TF 2.x ignores them with a warning and newer Keras rejects them.
    tensorboard = keras.callbacks.TensorBoard(
        log_dir='/tmp/my_tf_logs',
        histogram_freq=0,
        write_graph=True,
    )
    tensorboard.set_model(model)

    # Transform the train_on_batch() return value into the
    # dict expected by the callback's on_*_end methods
    def named_logs(model, logs):
        # train_on_batch() returns a bare scalar when only the loss is tracked
        if not isinstance(logs, (list, tuple)):
            logs = [logs]
        return dict(zip(model.metrics_names, logs))

    # Run training batches and notify TensorBoard after each one,
    # reporting the batch index as the "epoch"
    for batch_id in range(1000):
        # create_training_data() stands for your own data source
        x_train, y_train = create_training_data(batch_size)
        logs = model.train_on_batch(x_train, y_train)
        tensorboard.on_epoch_end(batch_id, named_logs(model, logs))

    tensorboard.on_train_end(None)
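Depending on whether the model tracks metrics besides the loss, train_on_batch() returns either a single scalar or a list of values, and zip() over a bare scalar fails. A slightly more defensive version of the `named_logs` idea, sketched in plain Python so it can be checked without a model; here `metrics_names` is passed in as a list of strings (an assumption for illustration) instead of being read from the model:

```python
from typing import Any, Dict, Sequence


def named_logs(metrics_names: Sequence[str], logs: Any) -> Dict[str, float]:
    """Map a train_on_batch() result onto {metric_name: value}.

    `logs` may be a bare scalar (loss only) or a sequence of values in
    the same order as `metrics_names`.
    """
    try:
        values = list(logs)  # list, tuple or 1-d array of values
    except TypeError:
        values = [logs]  # Python/NumPy scalar or 0-d array: loss only
    if len(values) != len(metrics_names):
        raise ValueError(
            f"expected {len(metrics_names)} values, got {len(values)}"
        )
    return {name: float(value) for name, value in zip(metrics_names, values)}


if __name__ == "__main__":
    print(named_logs(["loss"], 0.25))  # {'loss': 0.25}
    print(named_logs(["loss", "accuracy"], [0.5, 0.75]))  # {'loss': 0.5, 'accuracy': 0.75}
```

In newer TensorFlow 2.x releases, `model.train_on_batch(x, y, return_dict=True)` returns such a dict directly, which makes a helper like this unnecessary.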
erenon answered Sep 27 '22 21:09



I think that currently the only option is to use TensorFlow code. In a Stack Overflow answer I found a way to create a TensorBoard log manually.
A code sample using Keras train_on_batch() could therefore look like this:

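    # TensorFlow 1.x API (in TF 2.x, tf.summary.create_file_writer()
    # replaces tf.summary.FileWriter)
    import tensorflow as tf

    # before training: init the writer (for the TensorBoard log) and the model
    writer = tf.summary.FileWriter(...)
    model = ...

    # train the model on one batch; with no extra metrics this is the scalar loss
    loss = model.train_on_batch(...)
    summary = tf.Summary(value=[tf.Summary.Value(tag="loss",
                                                 simple_value=loss)])
    writer.add_summary(summary)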

Note: for this example you have to set the horizontal axis in TensorBoard to "RELATIVE", because no step is passed to the summary.
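The "RELATIVE" setting is only needed because no step is recorded. A minimal sketch of keeping a batch counter and passing it along, written against a hypothetical `write_scalar(tag, value, step)` callback so it runs without TensorFlow. With the TF 1.x writer above, that callback could build the `tf.Summary` and call `writer.add_summary(summary, global_step=step)`, after which the default STEP axis works. `BatchScalarLogger` and `write_scalar` are illustrative names, not part of Keras or TensorFlow:

```python
from typing import Callable, Dict, List, Tuple

# write_scalar(tag, value, step): any function that writes one scalar point
ScalarWriter = Callable[[str, float, int], None]


class BatchScalarLogger:
    """Hypothetical helper: logs per-batch scalars with an increasing step."""

    def __init__(self, write_scalar: ScalarWriter, start_step: int = 0) -> None:
        self._write_scalar = write_scalar
        self.step = start_step

    def log_batch(self, scalars: Dict[str, float]) -> int:
        """Write all scalars of one batch at the same step, then advance it."""
        step = self.step
        for tag, value in scalars.items():
            self._write_scalar(tag, float(value), step)
        self.step += 1
        return step


if __name__ == "__main__":
    # Stand-in writer that just records the points it receives
    points: List[Tuple[str, float, int]] = []
    logger = BatchScalarLogger(lambda tag, value, step: points.append((tag, value, step)))
    for loss in (0.9, 0.7, 0.4):
        logger.log_batch({"loss": loss})
    print(points)  # [('loss', 0.9, 0), ('loss', 0.7, 1), ('loss', 0.4, 2)]
```

All scalars from one batch share a step, so loss and any metrics line up on the same x position in TensorBoard.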

Oliver Wilken answered Sep 27 '22 22:09
