I am training a model in Keras and I want to plot the results after each epoch. I know that Keras callbacks provide an "on_epoch_end" function that can be overridden to run computations after each epoch, but my function takes some additional parameters, and passing them crashes the code with a metaclass error. The details are below.
Here is how I am doing it right now, which works fine:
from keras.callbacks import Callback

class NewCallback(Callback):
    def on_epoch_end(self, epoch, logs=None):  # working fine, prints the epoch after each epoch
        print("EPOCH IS: " + str(epoch))

epochs = 5
batch_size = 16
model_saved = False
if model_saved:
    vae.load_weights(args.weights)
else:
    # train the autoencoder
    vae.fit(x_train,
            epochs=epochs,
            batch_size=batch_size,
            validation_data=(x_test, None),
            callbacks=[NewCallback()])
But I want my callback to look like this:
class NewCallback(Callback, models, data, batch_size):
    def on_epoch_end(self, epoch, logs={}):
        print("EPOCH IS: " + str(epoch))
        x = models.predict(data)
        plt.plot(x)
        plt.savefig(str(epoch) + ".png")
If I call it like this in fit:
callbacks=[NewCallback(models, data, batch_size=batch_size)]
I get this error:
TypeError: metaclass conflict: the metaclass of a derived class must be a (non-strict) subclass of the metaclasses of all its bases
I am looking for a simpler way to call my function, or a way to resolve this metaclass error. Any help would be much appreciated!
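The metaclass error can be reproduced without Keras. In this minimal sketch, Callback and FakeModel are hypothetical stand-ins for keras.callbacks.Callback and a Keras model class. The broken class statement lists an instance as a base class; the fixed version accepts it as a constructor argument instead.

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback."""

class FakeModel:
    """Hypothetical stand-in for a Keras model class."""

models = FakeModel()  # an instance, like the `models` object in the question

# Broken: `models` is listed as a *base class*, not passed as an argument.
try:
    class BrokenCallback(Callback, models):
        pass
except TypeError as exc:
    error_message = str(exc)
    print(error_message)

# Fixed: inherit only from Callback and accept the extra objects in __init__.
class NewCallback(Callback):
    def __init__(self, models, data, batch_size):
        super().__init__()
        self.models = models
        self.data = data
        self.batch_size = batch_size

cb = NewCallback(models, data=[1, 2, 3], batch_size=16)
print(cb.batch_size)  # 16
```

Python picks a class's metaclass from type() of every base. Here type(Callback) is type while type(models) is FakeModel, and neither is a subclass of the other, so class creation fails with the same "metaclass conflict" TypeError. A real Keras model instance in the base list fails the same way.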
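What you want is a class that inherits from Callback and takes models, data, etc. as constructor arguments. Everything inside the parentheses of a class statement is treated as a base class, which is why listing models and data there triggers the metaclass conflict. So: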
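from keras.callbacks import Callback

class NewCallback(Callback):
    """NewCallback descends from Callback."""
    def __init__(self, models, data, batch_size):
        """Save params in constructor."""
        super().__init__()
        self.models = models
        self.data = data
        self.batch_size = batch_size

    def on_epoch_end(self, epoch, logs=None):
        x = self.models.predict(self.data, batch_size=self.batch_size)
        # ... plot x and save the figure for this epoch here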
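Once the extra objects are stored in __init__, fit() only calls hooks such as on_epoch_end(epoch, logs) on the instance, so the method can read them back from self. Keras also calls set_model() on each callback during fit(), which attaches the model being trained as self.model, so storing your own object under the name model would clash with it; that is one reason to use a different name such as self.models. The sketch below drives the hooks by hand with hypothetical stand-ins (Callback, FakeModel, and the predictions list) instead of Keras, so the logic can be checked without training anything.

```python
class Callback:
    """Hypothetical stand-in for keras.callbacks.Callback (only the parts used here)."""
    def __init__(self):
        self.model = None

    def set_model(self, model):
        # Keras calls this on every callback when fit() starts.
        self.model = model

    def on_epoch_end(self, epoch, logs=None):
        pass

class FakeModel:
    """Hypothetical stand-in whose predict() just doubles its inputs."""
    def predict(self, data, batch_size=None):
        return [2 * x for x in data]

class NewCallback(Callback):
    def __init__(self, models, data, batch_size):
        super().__init__()
        self.models = models  # the model(s) you want to evaluate
        self.data = data
        self.batch_size = batch_size
        self.predictions = []  # hypothetical: keep results for inspection

    def on_epoch_end(self, epoch, logs=None):
        x = self.models.predict(self.data, batch_size=self.batch_size)
        # In the real callback you would plot x and save it, e.g. to f"{epoch}.png"
        self.predictions.append((epoch, x))

# Roughly what fit() does with the callback (greatly simplified).
trained_model = FakeModel()
cb = NewCallback(FakeModel(), data=[1, 2, 3], batch_size=16)
cb.set_model(trained_model)
for epoch in range(2):
    cb.on_epoch_end(epoch, logs={})

print(cb.predictions)  # [(0, [2, 4, 6]), (1, [2, 4, 6])]
print(cb.model is trained_model)  # True
```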
If you want to make predictions on the test data after each epoch, you can try this:
import keras

class CustomCallback(keras.callbacks.Callback):
    def __init__(self, x_test, y_test):
        super().__init__()
        # No need to store the model: fit() attaches it as self.model
        self.x_test = x_test
        self.y_test = y_test

    def on_epoch_end(self, epoch, logs=None):
        # predict() takes only the inputs (its second positional argument is batch_size)
        y_pred = self.model.predict(self.x_test, verbose=0)
        print('y predicted: ', y_pred)

You need to pass the callback to model.fit:

model = keras.Sequential()
# ... your model architecture and model.compile(...) go here
model.fit(x_train, y_train, epochs=10,
          callbacks=[CustomCallback(x_test, y_test)])
Similar to on_epoch_end, Keras provides many other methods you can override:
on_train_begin, on_train_end, on_epoch_begin, on_epoch_end,
on_test_begin, on_test_end, on_predict_begin, on_predict_end,
on_train_batch_begin, on_train_batch_end, on_test_batch_begin, on_test_batch_end,
on_predict_batch_begin, on_predict_batch_end
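The training hooks nest: the train-level hooks wrap the whole run, the epoch-level hooks wrap each epoch, and the batch-level hooks wrap each step. The sketch below is a simplified, Keras-free model of that nesting; HookLogger and simplified_fit are hypothetical names, and the real fit() also handles logs, early stopping, and validation (which fires the on_test_* hooks before on_epoch_end).

```python
class HookLogger:
    """Hypothetical stand-in callback that records each hook call.

    It defines the same hook names as keras.callbacks.Callback but does not
    subclass it, so the sketch runs without Keras installed.
    """
    def __init__(self):
        self.calls = []

    def on_train_begin(self, logs=None):
        self.calls.append("train_begin")

    def on_epoch_begin(self, epoch, logs=None):
        self.calls.append(f"epoch_begin {epoch}")

    def on_train_batch_begin(self, batch, logs=None):
        self.calls.append(f"batch_begin {batch}")

    def on_train_batch_end(self, batch, logs=None):
        self.calls.append(f"batch_end {batch}")

    def on_epoch_end(self, epoch, logs=None):
        self.calls.append(f"epoch_end {epoch}")

    def on_train_end(self, logs=None):
        self.calls.append("train_end")

def simplified_fit(callback, epochs, steps_per_epoch):
    """Greatly simplified training loop: no real training and no validation hooks."""
    callback.on_train_begin()
    for epoch in range(epochs):
        callback.on_epoch_begin(epoch)
        for batch in range(steps_per_epoch):
            callback.on_train_batch_begin(batch)
            # ... one optimisation step would happen here ...
            callback.on_train_batch_end(batch)
        callback.on_epoch_end(epoch)
    callback.on_train_end()

logger = HookLogger()
simplified_fit(logger, epochs=1, steps_per_epoch=2)
print(logger.calls)
```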