Get Keras model input from inside a custom callback

I have a very simple question. I have a Keras model (TF backend) defined for classification. For debugging purposes, I want to dump the training images that are fed into my model during training. I am trying to create a custom callback that writes TensorBoard image summaries for this.

But how can I obtain the real training data inside the callback?

Currently I am trying this:

import tensorflow as tf
from keras import backend as K
from keras.callbacks import Callback

class TensorboardKeras(Callback):
    def __init__(self, model, log_dir, write_graph=True):
        super().__init__()  # initialise the base Callback
        self.model = model
        self.log_dir = log_dir
        self.session = K.get_session()

        # Image summary built on the model's input placeholder
        tf.summary.image('input_image', self.model.input)
        self.merged = tf.summary.merge_all()

        if write_graph:
            self.writer = tf.summary.FileWriter(self.log_dir, K.get_session().graph)
        else:
            self.writer = tf.summary.FileWriter(self.log_dir)

    def on_batch_end(self, batch, logs=None):
        # Nothing is fed for the input placeholder here, hence the error below
        summary = self.session.run(self.merged, feed_dict={})
        self.writer.add_summary(summary, batch)
        self.writer.flush()

But I am getting the error: InvalidArgumentError (see above for traceback): You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

There must be a way to see what the model gets as input, right?

Or maybe I should try another way to debug it?

asked Oct 12 '18 at 16:10 by Dmytro Prylipko



People also ask

How do I use callback in Keras?

Callbacks can be provided to the fit() function via the "callbacks" argument. First, instantiate the callbacks you intend to use. Then, add them to a Python list. Finally, pass that list as the callbacks argument when fitting the model.

What is on_epoch_end?

on_epoch_end: triggered when an epoch ends. on_batch_begin: triggered when a new batch is passed in for training. on_batch_end: triggered when training on a batch finishes. on_train_begin: triggered when training starts. on_train_end: triggered when training ends.
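To make the order of these hooks concrete, here is a toy sketch in plain Python. The stand-in `Callback` base class, `toy_fit` and `HookLogger` are hypothetical, not Keras code; they only mimic the order in which `fit()` invokes the hooks on every callback in the `callbacks` list. As in Keras, the hooks receive only an epoch or batch index and a `logs` dict of metrics, never the input batch itself, which is why the callback in the question cannot see the training images.

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback: every hook is a no-op."""
    def on_train_begin(self, logs=None): pass
    def on_train_end(self, logs=None): pass
    def on_epoch_begin(self, epoch, logs=None): pass
    def on_epoch_end(self, epoch, logs=None): pass
    def on_batch_begin(self, batch, logs=None): pass
    def on_batch_end(self, batch, logs=None): pass


def toy_fit(batches_per_epoch, epochs, callbacks):
    """Hypothetical training loop that only dispatches hooks in fit()'s order."""
    for cb in callbacks:
        cb.on_train_begin()
    for epoch in range(epochs):
        for cb in callbacks:
            cb.on_epoch_begin(epoch)
        for batch in range(batches_per_epoch):
            for cb in callbacks:
                cb.on_batch_begin(batch)
            # ... the real fit() would run one gradient step here ...
            for cb in callbacks:
                # Only metrics are passed on, not the images of the batch
                cb.on_batch_end(batch, logs={"loss": 0.0})
        for cb in callbacks:
            cb.on_epoch_end(epoch)
    for cb in callbacks:
        cb.on_train_end()


class HookLogger(Callback):
    """Records which hook fired, to show the call order."""
    def __init__(self):
        self.events = []
    def on_train_begin(self, logs=None): self.events.append("train_begin")
    def on_epoch_begin(self, epoch, logs=None): self.events.append(f"epoch_begin {epoch}")
    def on_batch_begin(self, batch, logs=None): self.events.append(f"batch_begin {batch}")
    def on_batch_end(self, batch, logs=None): self.events.append(f"batch_end {batch}")
    def on_epoch_end(self, epoch, logs=None): self.events.append(f"epoch_end {epoch}")
    def on_train_end(self, logs=None): self.events.append("train_end")


logger = HookLogger()
toy_fit(batches_per_epoch=2, epochs=1, callbacks=[logger])
```

With two batches and one epoch, `logger.events` ends up as train_begin, epoch_begin 0, batch_begin 0, batch_end 0, batch_begin 1, batch_end 1, epoch_end 0, train_end.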

What is Keras callback ModelCheckpoint used for?

The ModelCheckpoint callback is used together with model.fit() to save a model or its weights (in a checkpoint file) at some interval, so that the model or weights can be loaded later to resume training from the saved state.
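As a rough illustration of the "save only when the monitored metric improves" behaviour that ModelCheckpoint offers with `save_best_only=True`, here is a simplified plain-Python imitation. `BestCheckpoint` is hypothetical: instead of writing a checkpoint file it just records the epochs at which a save would happen, and it assumes a lower `val_loss` is better.

```python
import math


class BestCheckpoint:
    """Hypothetical, simplified imitation of ModelCheckpoint(save_best_only=True).

    Records the epochs at which the real callback would write a checkpoint,
    assuming the monitored metric should be minimised.
    """
    def __init__(self, monitor="val_loss"):
        self.monitor = monitor
        self.best = math.inf
        self.saved_epochs = []

    def on_epoch_end(self, epoch, logs=None):
        current = (logs or {}).get(self.monitor)
        if current is None:
            return  # metric not available this epoch; nothing to compare
        if current < self.best:
            self.best = current
            self.saved_epochs.append(epoch)  # real callback: save model/weights here


ckpt = BestCheckpoint()
for epoch, val_loss in enumerate([0.50, 0.40, 0.45, 0.30]):
    ckpt.on_epoch_end(epoch, logs={"val_loss": val_loss})
print(ckpt.saved_epochs)  # → [0, 1, 3]
```

Epoch 2 is skipped because 0.45 is worse than the best value seen so far (0.40).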

What is callback in model fit Keras?

A callback is an object that can perform actions at various stages of training (e.g. at the start or end of an epoch, before or after a single batch, etc). You can use callbacks to: Write TensorBoard logs after every batch of training to monitor your metrics. Periodically save your model to disk.


1 Answer

You don't need callbacks for this. All you need to do is implement a generator function that yields a batch of images and their labels as a tuple. The flow_from_directory function has a parameter called save_to_dir, which may already cover everything you need. In case it doesn't, here is what you can do:

from keras.preprocessing.image import ImageDataGenerator

def trainGenerator(batch_size, train_path, image_size, seed=1):
    # Preprocessing, see https://keras.io/preprocessing/image/ for details
    image_datagen = ImageDataGenerator(horizontal_flip=True)
    # Create the image generator, see https://keras.io/preprocessing/image/#flow_from_directory for details
    train_generator = image_datagen.flow_from_directory(
        train_path,
        class_mode="categorical",
        target_size=image_size,          # (height, width), without the channel dimension
        batch_size=batch_size,
        save_prefix="augmented_train",   # only used when save_to_dir is also set
        seed=seed)

    for (batch_imgs, batch_labels) in train_generator:
        # Do other stuff here, such as dumping images or further augmenting them
        yield (batch_imgs, batch_labels)


# `model` is the compiled Keras model from the question
t_generator = trainGenerator(32, "./train_data", (224, 224))
model.fit_generator(t_generator, steps_per_epoch=10, epochs=1)
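The "dump inside the generator" idea can also be kept separate from the data pipeline by wrapping any existing batch generator. Below is a minimal sketch, assuming batches arrive as (images, labels) tuples; `dump_batches` and the `dump_fn` hook are hypothetical names, and what `dump_fn` does (write image files, TensorBoard summaries, and so on) is up to you.

```python
def dump_batches(batches, dump_fn, every_n=1):
    """Yield (images, labels) unchanged, calling dump_fn on every n-th batch.

    `dump_fn(index, images, labels)` is a hypothetical hook; in practice it
    could write the images to disk or to a TensorBoard summary.
    """
    for i, (images, labels) in enumerate(batches):
        if i % every_n == 0:
            dump_fn(i, images, labels)
        yield images, labels


# Usage with stand-in batches (plain lists instead of NumPy arrays):
seen = []
fake_batches = [([[0]], [1]), ([[1]], [0]), ([[2]], [1])]
passed_on = list(dump_batches(fake_batches, lambda i, x, y: seen.append(i), every_n=2))
print(seen)  # → [0, 2]
```

With a real model this would be used as, e.g., `model.fit_generator(dump_batches(t_generator, my_dump_fn), steps_per_epoch=10, epochs=1)`. Because fit_generator pre-fetches batches from the generator into a queue, the dump may run a few batches ahead of the training step, but the dumped arrays are exactly the ones fed to the model.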
answered Sep 28 '22 at 01:09 by Mete Han Kahraman
