Tensorflow batch loss spikes when restoring model for training from saved checkpoint?

Tags:

tensorflow

I'm encountering a strange issue that I've been trying to debug without much luck. My model starts training properly, with the batch loss decreasing consistently (from ~6000 initially to ~120 after 20 epochs). However, when I pause training and later resume it by restoring the model from the checkpoint, the batch loss spikes well above its value before the pause and then resumes decreasing from that higher point. My worry is that when I restore the model for evaluation, I may not be using the trained model that I think I am.

I have combed over my code several times, comparing it to the Tensorflow tutorials, and tried to ensure that I was saving and restoring using the tutorial-suggested methods. Here is the code snapshot: https://github.com/KaranKash/DigitSpeak/tree/b7dad3128c88061ee374ae127579ec25cc7f5286 - train.py contains the saving and restoring steps, the graph setup, and the training loop, while model.py builds the network layers and computes the loss.

Here is an example from my print statements - notice how the batch loss rises sharply when training resumes from the checkpoint saved at the start of epoch 7:

Epoch 6. Batch 31/38. Loss 171.28
Epoch 6. Batch 32/38. Loss 167.02
Epoch 6. Batch 33/38. Loss 173.29
Epoch 6. Batch 34/38. Loss 159.76
Epoch 6. Batch 35/38. Loss 164.17
Epoch 6. Batch 36/38. Loss 161.57
Epoch 6. Batch 37/38. Loss 165.40
Saving to /Users/user/DigitSpeak/cnn/model/model.ckpt
Epoch 7. Batch 0/38. Loss 169.99
Epoch 7. Batch 1/38. Loss 178.42
KeyboardInterrupt
dhcp-18-189-118-233:cnn user$ python train.py
Starting loss calculation...
Found in-progress model. Will resume from there.
Epoch 7. Batch 0/38. Loss 325.97
Epoch 7. Batch 1/38. Loss 312.10
Epoch 7. Batch 2/38. Loss 295.61
Epoch 7. Batch 3/38. Loss 306.96
Epoch 7. Batch 4/38. Loss 290.58
Epoch 7. Batch 5/38. Loss 275.72
Epoch 7. Batch 6/38. Loss 251.12

I've printed the contents of the checkpoint with the inspect_checkpoint.py script (output below; a sketch of the inspection approach follows the listing). I've also experimented with other optimizers (Adam and GradientDescentOptimizer) and noticed the same behavior with respect to the spiked loss after resuming training.

dhcp-18-189-118-233:cnn user$ python inspect_checkpoint.py
Optimizer/Variable (DT_INT32) []
conv1-layer/bias (DT_FLOAT) [64]
conv1-layer/bias/Momentum (DT_FLOAT) [64]
conv1-layer/weights (DT_FLOAT) [5,23,1,64]
conv1-layer/weights/Momentum (DT_FLOAT) [5,23,1,64]
conv2-layer/bias (DT_FLOAT) [512]
conv2-layer/bias/Momentum (DT_FLOAT) [512]
conv2-layer/weights (DT_FLOAT) [5,1,64,512]
conv2-layer/weights/Momentum (DT_FLOAT) [5,1,64,512]
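
Roughly, the inspection amounts to something like this (a sketch, not my exact inspect_checkpoint.py; the path is the one from the training log above):

import tensorflow as tf

# List every variable stored in the checkpoint together with its shape.
reader = tf.train.NewCheckpointReader("/Users/user/DigitSpeak/cnn/model/model.ckpt")
for name, shape in sorted(reader.get_variable_to_shape_map().items()):
    print(name, shape)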
asked Oct 30 '22 by kashkar

1 Answer

I ran into this issue and found that I was re-initializing the graph variables after restoring the graph, which threw away all of the learned parameters and replaced them with whatever initialization values were originally specified for each variable in the original graph definition.

For example, if you use tf.global_variables_initializer() to initialize variables in your model program, make sure that whatever control logic decides a saved graph will be restored skips sess.run(tf.global_variables_initializer()) on the restore path.
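
As a rough sketch (TF 1.x style; the placeholder variable and the checkpoint directory below are illustrative, not the asker's actual code), the control flow should look something like this:

import tensorflow as tf

# Placeholder variable standing in for the real model parameters.
weights = tf.Variable(tf.zeros([5, 23, 1, 64]), name="demo_weights")

saver = tf.train.Saver()
checkpoint_dir = "/Users/user/DigitSpeak/cnn/model"  # directory from the training log

with tf.Session() as sess:
    ckpt = tf.train.latest_checkpoint(checkpoint_dir)
    if ckpt is not None:
        # Restore the learned parameters. Do NOT run the initializer on this path,
        # or every restored value gets overwritten with fresh initial values.
        saver.restore(sess, ckpt)
    else:
        # Only a brand-new model should be initialized.
        sess.run(tf.global_variables_initializer())
    # ... training loop continues here ...

The key point is that sess.run(tf.global_variables_initializer()) only belongs on the branch where no checkpoint exists.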

This was a simple but costly mistake for me, so I hope it saves someone else a few grey hairs (or hairs in general).

answered Nov 15 '22 by MrTallz