I'm using a callback in Keras to record the `loss` and `val_loss` per epoch, but I would like to do the same per batch. I found a callback method called `on_batch_begin(self, batch, logs={})`, but I'm not sure how to use it.
Here is an example of a custom callback, adapted from an existing example:
class LossHistory(keras.callbacks.Callback):
    """Record the training loss after every batch and the validation loss after every epoch.

    Attributes (reset in ``on_train_begin``):
        losses: one ``'loss'`` value per training batch.
        val_losses: one ``'val_loss'`` value per epoch (``None`` entries if no
            validation data was given to ``fit``).
    """

    def on_train_begin(self, logs=None):
        # Start fresh lists at the beginning of each training run.
        self.losses = []
        self.val_losses = []

    def on_batch_end(self, batch, logs=None):
        # `logs` defaults to None (not a shared mutable {}), so guard the lookup.
        self.losses.append((logs or {}).get('loss'))

    def on_epoch_end(self, epoch, logs=None):
        # Validation is evaluated once per epoch, so 'val_loss' is not in the
        # batch-level logs. Reading it in on_batch_end would only record None.
        self.val_losses.append((logs or {}).get('val_loss'))
# NOTE(review): Sequential, Dense, Activation, X_train and Y_train are assumed
# to be imported/defined elsewhere; they are not shown in this snippet.
model = Sequential()
# Keras 2+ spelling of the old Keras 1 `init='uniform'` argument.
model.add(Dense(10, input_dim=784, kernel_initializer='random_uniform'))
model.add(Activation('softmax'))
model.compile(loss='categorical_crossentropy', optimizer='rmsprop')

history = LossHistory()
# `epochs` replaces the removed Keras 1 `nb_epoch` argument.
model.fit(X_train, Y_train, batch_size=128, epochs=20, verbose=0,
          validation_split=0.1, callbacks=[history])

print(history.losses)
# outputs:
# [0.66047596406559383, 0.3547245744908703, ..., 0.25953155204159617, 0.25901699725311789]
print(history.val_losses)
import numpy as np
import matplotlib.pyplot as plt
import keras
class LossHistory(keras.callbacks.Callback):
    """Collect training loss per batch and validation loss per epoch.

    After ``fit`` returns, ``self.history['loss']`` has one entry per training
    batch and ``self.history['val_loss']`` has one entry per epoch.
    """

    def on_train_begin(self, logs=None):
        # Reset at the start of each training run so results don't accumulate.
        self.history = {'loss': [], 'val_loss': []}

    def on_batch_end(self, batch, logs=None):
        # NOTE(review): in recent Keras versions the batch-level 'loss' is the
        # running mean over the current epoch so far, not this single batch's
        # loss. Confirm for the Keras version in use.
        self.history['loss'].append((logs or {}).get('loss'))

    def on_epoch_end(self, epoch, logs=None):
        # Validation runs once per epoch, so 'val_loss' is read here; it is
        # absent (None) when fit() gets no validation data.
        self.history['val_loss'].append((logs or {}).get('val_loss'))
# Callback instance that collects per-batch/per-epoch losses during fit().
history = LossHistory()

# Small binary classifier: 100 inputs -> 32 ReLU units -> 1 sigmoid output.
model = keras.Sequential()
model.add(keras.layers.Dense(32, activation='relu', input_dim=100))
model.add(keras.layers.Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy')

# Dummy data: 1000 samples of 100 uniform features and random 0/1 labels.
# (numpy is already imported as `np` at the top of this snippet.)
data = np.random.random((1000, 100))
labels = np.random.randint(2, size=(1000, 1))

# Train for 10 epochs in batches of 32 samples, holding out 20% of the data
# for validation so that 'val_loss' is reported at each epoch end.
model.fit(
    data,
    labels,
    epochs=10,
    batch_size=32,
    validation_split=0.2,
    callbacks=[history],
)
# Plot the history: per-batch training loss against batch index, with the
# per-epoch validation loss overlaid at the batch that ended each epoch.
y1 = history.history['loss']
y2 = history.history['val_loss']
x1 = np.arange(len(y1))

fig, ax = plt.subplots()
ax.plot(x1, y1, label='loss')

# Only overlay validation loss if some was recorded. Without validation data
# the list may be empty (division by zero below) or hold None values.
if y2 and all(v is not None for v in y2):
    # Batches per epoch. Integer division keeps positions integral; this
    # assumes every epoch ran the same number of batches.
    k = len(y1) // len(y2)
    # Epoch e (1-based) ends after batch index e*k - 1 in x1's 0-based indexing,
    # so the last validation point lines up with the last training point.
    x2 = np.arange(1, len(y2) + 1) * k - 1
    ax.plot(x2, y2, label='val_loss')

ax.set_xlabel('batch')
ax.set_ylabel('loss')
# Without this call the `label=` arguments above are never displayed.
ax.legend()
plt.show()
If you like our work, you can donate to us via PayPal or buy us a coffee so we can maintain and grow the site. Thank you!
Donate Us With