Python/Keras - accessing ModelCheckpoint callback

Tags:

python

keras

I'm using Keras to predict a time series. By default I train for 20 epochs, and I want to know what my neural network predicted at each of those 20 epochs.

model.predict only gives me the predictions of the final model. However, I want the predictions from every epoch, or at least from the last 10 (which have acceptable error levels).

To get at them I'm trying the ModelCheckpoint callback from Keras, but I'm having trouble accessing it afterwards. I'm using the following code:

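from keras.models import Sequential
from keras.layers import GRU, Dense, Activation
from keras.callbacks import ModelCheckpoint

# predictor_train, target_train, col and batch are defined earlier in the script (not shown)
model=Sequential()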

model.add(GRU(input_dim=col,init='uniform',output_dim=20))
model.add(Dense(10))
model.add(Dense(5))
model.add(Activation("softmax"))
model.add(Dense(1))

model.compile(loss="mae", optimizer="RMSprop")

checkpoint=ModelCheckpoint(filepath='/Users/Alex/checkpoint.hdf5')

model.fit(X=predictor_train, y=target_train, nb_epoch=20, batch_size=batch,validation_split=0.1) #best validation split at 0.1
model.evaluate(X=predictor_train, y=target_train,batch_size=batch,show_accuracy=True)

print(checkpoint)

Specifically, my questions are:

  • I expected that after running the code I would find a file named checkpoint.hdf5 inside the folder /Users/Alex, but there is no such file. What am I missing?

  • When I print checkpoint, all I get is <keras.callbacks.ModelCheckpoint object at 0x117471290>. Is there a way to print the per-epoch predictions instead? What would the code look like?

Your help is very much appreciated :)

aabujamra asked Apr 26 '16 at 21:04



People also ask

What is Keras callbacks ModelCheckpoint?

The ModelCheckpoint callback is used together with training via model.fit() to save a model or its weights (in a checkpoint file) at some interval, so that the model or weights can be loaded later, for example to continue training from the saved state.

How do I use callback in Keras?

Callbacks are provided to the fit() function via the "callbacks" argument. First, instantiate the callbacks. Then add the ones you intend to use to a Python list. Finally, pass that list as the callbacks argument when fitting the model.
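The three steps above (instantiate, put in a list, pass to fit()) plus the "load later" part can be sketched as follows. This is a minimal sketch, assuming a current Keras 2.x/3 install imported as `keras` (so `epochs=` rather than the older `nb_epoch=`); the toy data, the `build_model` helper and the temporary output folder are hypothetical stand-ins for the question's own data and GRU model. It saves the weights after every epoch to a separate file and then reloads each file to predict with it, which is one way to get per-epoch predictions.

```python
import glob
import os
import tempfile

import numpy as np
import keras

# Hypothetical toy regression data standing in for predictor_train / target_train.
rng = np.random.default_rng(0)
x_train = rng.normal(size=(100, 4)).astype("float32")
y_train = x_train.sum(axis=1, keepdims=True)


def build_model():
    # Hypothetical small model; the question's GRU model would be used the same way.
    model = keras.Sequential([
        keras.Input(shape=(4,)),
        keras.layers.Dense(8, activation="relu"),
        keras.layers.Dense(1),
    ])
    model.compile(loss="mae", optimizer="rmsprop")
    return model


out_dir = tempfile.mkdtemp()

# 1. Instantiate the callback; the {epoch:02d} placeholder gives each epoch its own file.
#    Keras 3 requires the ".weights.h5" suffix when save_weights_only=True.
checkpoint = keras.callbacks.ModelCheckpoint(
    filepath=os.path.join(out_dir, "weights-{epoch:02d}.weights.h5"),
    save_weights_only=True,
)

# 2. Put it in a list and 3. pass that list to fit() via callbacks=.
model = build_model()
model.fit(x_train, y_train, epochs=3, batch_size=16,
          validation_split=0.1, verbose=0, callbacks=[checkpoint])

# Later: rebuild the architecture, load each saved epoch and predict with it.
saved_files = sorted(glob.glob(os.path.join(out_dir, "weights-*.weights.h5")))
per_epoch_predictions = []
for path in saved_files:
    restored = build_model()
    restored.load_weights(path)
    per_epoch_predictions.append(restored.predict(x_train[:5], verbose=0))
```

Saving only the weights keeps the files small; the trade-off is that the same architecture must be rebuilt before loading each file.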


1 Answer

There are two problems in this code:

  • You are not passing the callback to the model's fit method. This is done with the keyword argument "callbacks".
  • The filepath should contain placeholders (like "{epoch:02d}-{val_loss:.2f}") that Keras fills in with str.format, so that each epoch is saved to a different file.

So the correct version should be something like:

checkpoint = ModelCheckpoint(filepath='/Users/Alex/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5')

model.fit(X=predictor_train, y=target_train, nb_epoch=20,
          batch_size=batch, validation_split=0.1, callbacks=[checkpoint])

You can also add other kinds of callbacks in the list that is assigned to that keyword.
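To see why the placeholders matter, the plain-Python sketch below imitates how a filepath template is expanded with str.format. The logs values are made up for illustration; during training Keras supplies the real epoch number and metrics (recent releases pass a 1-based epoch, which `start=1` imitates here; treat that numbering as an assumption for older versions).

```python
# Hypothetical per-epoch logs; in real training Keras provides these values.
template = "/Users/Alex/checkpoint-{epoch:02d}-{val_loss:.2f}.hdf5"
logs_per_epoch = [
    {"loss": 0.512, "val_loss": 0.4871},
    {"loss": 0.431, "val_loss": 0.4219},
]

# Placeholders are filled with str.format; unused keys such as "loss" are ignored.
filenames = [
    template.format(epoch=epoch, **logs)
    for epoch, logs in enumerate(logs_per_epoch, start=1)
]
for name in filenames:
    print(name)
# /Users/Alex/checkpoint-01-0.49.hdf5
# /Users/Alex/checkpoint-02-0.42.hdf5

# Without placeholders every epoch formats to the same path,
# so each save overwrites the previous one.
fixed = "/Users/Alex/checkpoint.hdf5"
same_paths = {
    fixed.format(epoch=epoch, **logs)
    for epoch, logs in enumerate(logs_per_epoch, start=1)
}
```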

Unfortunately, the ModelCheckpoint object does not store the training history or any predictions (it only writes the model or its weights to disk), so they cannot be recovered from it.
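If the goal is the predictions themselves, one option is a small custom callback that calls predict at the end of every epoch and keeps the results in memory. This is a minimal sketch, assuming a current Keras 2.x/3 install imported as `keras`; `PredictionHistory`, the toy data and the model are hypothetical. It also keeps the `History` object that `model.fit()` returns, whose `history` dict holds the per-epoch loss values.

```python
import numpy as np
import keras


class PredictionHistory(keras.callbacks.Callback):
    """Hypothetical helper: stores predictions on a fixed input after every epoch."""

    def __init__(self, x_eval):
        super().__init__()
        self.x_eval = x_eval
        self.predictions = []  # one array per finished epoch

    def on_epoch_end(self, epoch, logs=None):
        # self.model is set by Keras when the callback is passed to fit().
        self.predictions.append(self.model.predict(self.x_eval, verbose=0))


# Hypothetical toy data in place of predictor_train / target_train.
rng = np.random.default_rng(0)
x = rng.normal(size=(64, 4)).astype("float32")
y = x.sum(axis=1, keepdims=True)

model = keras.Sequential([
    keras.Input(shape=(4,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dense(1),
])
model.compile(loss="mae", optimizer="rmsprop")

recorder = PredictionHistory(x[:10])
history = model.fit(x, y, epochs=3, batch_size=16, validation_split=0.1,
                    verbose=0, callbacks=[recorder])

# recorder.predictions[i] holds the predictions made after epoch i + 1;
# history.history["loss"] and history.history["val_loss"] hold the per-epoch losses.
```

Keeping the last 10 epochs, as the question asks, is then just `recorder.predictions[-10:]` when training for 20 epochs. Storing arrays in memory avoids writing a file per epoch, at the cost of running predict once per epoch.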

Dr. Snoopy answered Oct 21 '22 at 12:10
