
Use custom function with custom parameters in keras callback

I am training a model in Keras and I want to plot graphs of the results after each epoch. I know that Keras callbacks provide an "on_epoch_end" method that can be overridden to run some computation after each epoch, but my function needs some additional parameters, and when I pass them the code crashes with a metaclass error. The details are given below.

Here is what I have right now, which works fine:

class NewCallback(Callback):

    def on_epoch_end(self, epoch, logs={}):  #working fine, printing epoch after each epoch
        print("EPOCH IS: "+str(epoch))


epochs=5
batch_size = 16
model_saved=False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            callbacks=[NewCallback()])

But I want my callback to look like this:

class NewCallback(Callback,models,data,batch_size):
   def on_epoch_end(self, epoch, logs={}):
     print("EPOCH IS: "+str(epoch))
     x=models.predict(data)
     plt.plot(x)
     plt.savefig(epoch+".png")

If I call it like this in fit:

callbacks=[NewCallback(models, data, batch_size=batch_size)]

I get this error:

TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases 

I am looking for a simpler way to call my function, or a way to resolve this metaclass error. Any help will be much appreciated!

Asim asked Sep 02 '25 15:09


2 Answers

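I think what you want is a class that descends from Callback and takes models, data, etc. as constructor arguments. Anything listed in the parentheses after the class name is treated as a base class, not as an argument, which is what triggers the metaclass error. So: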

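class NewCallback(Callback):
    """ NewCallback descends from Callback
    """
    def __init__(self, models, data, batch_size):
        """ Save params in constructor
        """
        super().__init__()  # let the base Callback initialise itself
        self.models = models
        self.data = data
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        # Runs after every epoch, using the objects saved in the constructor
        x = self.models.predict(self.data)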
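
To see why the version in the question fails, here is a minimal sketch that reproduces the same error in plain Python, without Keras. `Base`, `FakeModel` and `Broken` are made-up stand-ins for illustration only, not Keras classes. The point is that putting an object (rather than a class) in the base-class list is enough to raise the metaclass conflict.

```python
class Base:
    """Stand-in for keras.callbacks.Callback (hypothetical, for illustration)."""


class FakeModel:
    """Stand-in for a model object (hypothetical)."""


model = FakeModel()   # an instance, not a class
batch_size = 16       # a plain int

error_message = None
try:
    # Everything in these parentheses is treated as a base class,
    # so Python tries to inherit from an instance and from an int.
    class Broken(Base, model, batch_size):
        pass
except TypeError as exc:
    error_message = str(exc)

print(error_message)
```

Moving those names into `__init__`, as in the class above, avoids the problem because they become ordinary constructor arguments.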
Pedro Marques answered Sep 05 '25 14:09


In case you want to make predictions on the test data, you can try this:

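class CustomCallback(keras.callbacks.Callback):
    def __init__(self, model, x_test, y_test):
        super().__init__()
        # Keras attaches the model being trained as self.model during fit(),
        # so keep our own reference under a different name to avoid a clash.
        self.eval_model = model
        self.x_test = x_test
        self.y_test = y_test  # kept for comparing against the predictions

    def on_epoch_end(self, epoch, logs=None):
        # predict() takes only the inputs; its second positional argument
        # is batch_size, so y_test must not be passed here.
        y_pred = self.eval_model.predict(self.x_test)
        print('y predicted: ', y_pred)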

You need to pass the callback to model.fit:

model = keras.Sequential()
# your model architecture
model.fit(x_train, y_train, epochs=10,
          callbacks=[CustomCallback(model, x_test, y_test)])

Besides on_epoch_end, Keras callbacks provide many other hook methods. The full set is:

on_train_begin, on_train_end, on_epoch_begin, on_epoch_end, on_test_begin,
on_test_end, on_predict_begin, on_predict_end, on_train_batch_begin, on_train_batch_end,
on_test_batch_begin, on_test_batch_end, on_predict_batch_begin, on_predict_batch_end
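
To get a feel for when a few of these hooks fire, here is a rough sketch that needs no Keras at all. `RecordingCallback` and `toy_fit` are hypothetical names invented for this example; the loop only mimics the order in which fit() calls the train and epoch hooks and is not Keras's actual code. It also shows one way to build a per-epoch file name, since `epoch` is an int and `epoch+".png"` from the question would raise a TypeError.

```python
class RecordingCallback:
    """Duck-typed stand-in for a Keras callback (hypothetical, no Keras import)."""

    def __init__(self, data, prefix="epoch"):
        # Extra parameters are stored on the instance in the constructor.
        self.data = data
        self.prefix = prefix
        self.events = []

    def on_train_begin(self, logs=None):
        self.events.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.events.append(f"epoch_begin:{epoch}")

    def on_epoch_end(self, epoch, logs=None):
        # epoch is an int, so format it instead of concatenating with a str.
        self.events.append(f"epoch_end:{self.prefix}_{epoch}.png")

    def on_train_end(self, logs=None):
        self.events.append("train_end")


def toy_fit(callback, epochs):
    """Toy loop that only mimics the order of the hooks; not Keras's code."""
    callback.on_train_begin()
    for epoch in range(epochs):  # epochs are numbered from 0
        callback.on_epoch_begin(epoch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()


cb = RecordingCallback(data=[1, 2, 3])
toy_fit(cb, epochs=2)
print(cb.events)
```

In a real callback you would put the plotting or prediction code in on_epoch_end and use the formatted name when saving the figure.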
Harshal Deore answered Sep 05 '25 15:09