
How to implement Beholder (Tensorboard plugin) for Keras?

I am trying to use the Beholder plugin from TensorBoard in a simple CNN (I am a beginner at TensorFlow), but I am not sure where to put visualizer.update(session=session). At the beginning I have:

from tensorboard.plugins.beholder import Beholder
LOG_DIRECTORY='/tmp/tensorflow_logs'
visualizer = Beholder(logdir=LOG_DIRECTORY)

I train my model like this:

model = Sequential()
model.add(Conv2D(32, (3, 3), input_shape=(253,27,3))) 
.
.
.
model.compile(loss='binary_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

Where should I put visualizer.update(session=session), and what else do I need to add to my code? For now it only says "No Beholder data was found." Thank you!

Gica asked Apr 24 '19 23:04


1 Answer

A good approach is to create a custom Keras callback, so that you can call visualizer.update(session=session) at the end of each epoch (or whenever you want). Here is an example of what such a callback could look like:

from tensorboard.plugins.beholder import Beholder
import tensorflow as tf
import keras.backend as K
import keras

LOG_DIRECTORY='/tmp/tensorflow_logs'


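class BeholderCallback(keras.callbacks.Callback):
    """Keras callback that pushes one frame to Beholder after every epoch."""

    def __init__(self, tensor, logdir=LOG_DIRECTORY, sess=None):
        super().__init__()  # let Keras initialise its own callback state
        self.visualizer = Beholder(logdir=logdir)
        # Fall back to the session Keras is already using
        self.sess = sess if sess is not None else K.get_session()
        self.tensor = tensor

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate the tensor to a NumPy array; depending on the tensor,
        # this might require a feed_dict
        frame = self.sess.run(self.tensor)
        self.visualizer.update(
            session=self.sess,
            frame=frame
        )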

Then, after defining your model, instantiate the callback and pass it to model.fit:

# Define your Keras model
# ...

# Prepare callback
sess = K.get_session() 
beholder_callback = BeholderCallback(your_tensor, sess=sess)

# Fit data into model and pass callback to model.fit
model.fit(x=x_train,
          y=y_train,
          callbacks=[beholder_callback])

You could also use the arrays argument of visualizer.update in a similar way.
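As a rough illustration of the arrays variant, here is a minimal sketch. The class name BeholderArraysCallback and its constructor arguments are made up for this example; the visualizer and session are passed in rather than created inside the callback, and the import falls back to a plain object base only so the snippet can be run without Keras installed. It assumes the tensors can be evaluated without a feed_dict (for example, weight variables).

```python
try:
    from keras.callbacks import Callback
except ImportError:
    # Fallback so the sketch can be run without Keras installed.
    Callback = object


class BeholderArraysCallback(Callback):
    """Hypothetical callback: sends evaluated tensors to Beholder as `arrays`."""

    def __init__(self, visualizer, sess, tensors):
        super().__init__()
        self.visualizer = visualizer  # e.g. Beholder(logdir=LOG_DIRECTORY)
        self.sess = sess              # e.g. K.get_session()
        self.tensors = list(tensors)  # assumption: need no feed_dict

    def on_epoch_end(self, epoch, logs=None):
        # Evaluate all tensors in one call and hand the results to Beholder
        arrays = self.sess.run(self.tensors)
        self.visualizer.update(session=self.sess, arrays=arrays)
```

With the setup from the answer, this might be instantiated as BeholderArraysCallback(Beholder(logdir=LOG_DIRECTORY), K.get_session(), model.trainable_weights) and passed to model.fit through callbacks, just like the frame-based callback. Using model.trainable_weights is only one possible choice of tensors.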

rvinas answered Oct 17 '22 18:10