I've trained a CNN model in TensorFlow eager mode. Now I'm trying to restore the trained model from a checkpoint file, but I haven't had any success.
All the examples I've found (like the one below) restore the checkpoint into a Session, but what I need is to restore the model in eager mode, i.e. without creating a session.
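```python
import tensorflow as tf

saver = tf.train.Saver()  # assumes the model's variables are already defined in the graph
with tf.Session() as sess:
    # Restore variables from disk.
    saver.restore(sess, "/tmp/model.ckpt")
```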
Basically what I need is something like:
```python
tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)
```
and then I can use the model to make predictions.
Can someone please help?
Update
The example code can be found at: MNIST eager mode demo
I've tried to follow the steps from @Jay Shah's answer, and it almost worked, but the restored model doesn't have any variables in it:
```python
tfe.save_network_checkpoint(model, './test/my_model.ckpt')
# Out[58]: './test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2, './test/my_model.ckpt-1720')
model2.variables
# Out[72]: []
```
The original model has lots of variables in it:
```
model.variables
[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
array([[[[ -8.25184360e-02, 6.77833706e-03, 6.97569922e-02,...
```
Eager execution is still a new feature in TensorFlow and was not included in the latest release, so not all features are supported; fortunately, loading a model from a saved checkpoint is.
You'll need to use the tfe.Saver class (which is a thin wrapper over the tf.train.Saver class), and your code should look something like this:
```python
import tensorflow.contrib.eager as tfe

saver = tfe.Saver([x, y])  # x, y: the variables to restore
saver.restore('/tmp/ckpt')
```
Here `[x, y]` is the list of variables you wish to restore (for a model, that is its `variables` list). It must match the variables that were passed to the Saver that created the checkpoint.
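To illustrate why the variable list must match, here is a toy, pure-Python sketch of name-keyed saving and restoring. It is not TensorFlow's checkpoint format: `ToyVariable`, `save_checkpoint`, `restore_checkpoint` and the JSON file are hypothetical. What it shows is that a checkpoint stores values under variable names, so a variable whose name differs from the saved one has nothing to restore from.

```python
import json
import os
import tempfile


class ToyVariable:
    """Hypothetical stand-in for a named variable (not a TensorFlow class)."""

    def __init__(self, name, value):
        self.name = name
        self.value = value


def save_checkpoint(variables, path):
    # Store every value under its variable name, like a checkpoint key.
    with open(path, "w") as f:
        json.dump({v.name: v.value for v in variables}, f)


def restore_checkpoint(variables, path):
    # Restore by name: each variable needs a matching key in the file.
    with open(path) as f:
        stored = json.load(f)
    for v in variables:
        if v.name not in stored:
            raise KeyError(f"Key {v.name} not found in checkpoint")
        v.value = stored[v.name]


ckpt = os.path.join(tempfile.mkdtemp(), "toy_ckpt.json")
save_checkpoint([ToyVariable("dense/kernel", [1.0, 2.0])], ckpt)

same_name = ToyVariable("dense/kernel", [0.0, 0.0])
restore_checkpoint([same_name], ckpt)
print(same_name.value)  # [1.0, 2.0]

renamed = ToyVariable("dense_1/kernel", [0.0, 0.0])
try:
    restore_checkpoint([renamed], ckpt)
except KeyError as err:
    print("restore failed:", err)  # value of `renamed` stays unchanged
```

A real `tf.train.Saver` behaves analogously in spirit: it looks variables up by name, so a renamed variable fails to restore rather than silently picking up another variable's values.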
More details, including sample code, can be found here, and the API details of the saver can be found here.
OK, after spending a few hours running the code line by line, I've figured out a way to restore a checkpoint to a new TensorFlow eager mode model.
Using the example from TF Eager Mode MNIST:
Steps:

1. After your model has been trained, find the index file of the latest checkpoint (or whichever checkpoint you want) in the checkpoint folder created during training, such as `ckpt-25800.index`. Use only the prefix `ckpt-25800` (without `.index`) when restoring in step 5.
2. Start a new Python terminal and enable TensorFlow eager mode:

   ```python
   import numpy as np
   import tensorflow as tf
   import tensorflow.contrib.eager as tfe

   tfe.enable_eager_execution()
   ```

3. Create a new instance of `MNISTModel` (the model class from the eager MNIST example; it must be defined or imported in this session):

   ```python
   model_new = MNISTModel()
   ```

4. Initialise the variables of `model_new` by calling it once on dummy input (a forward pass with `training=True`; no weights are updated). This step is important: without initialising the variables first, they can't be restored by the next step. However, I couldn't find another way to initialise variables in eager mode than the one below.

   ```python
   model_new(tfe.Variable(np.zeros((1, 784), dtype=np.float32)), training=True)
   ```

5. Restore the variables of `model_new` from the checkpoint identified in step 1:

   ```python
   tfe.Saver(model_new.variables).restore('./tf_checkpoints/ckpt-25800')
   ```

If the restore is successful, you should see something like:

```
INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800
```
Now the checkpoint has been successfully restored to model_new and you can use it to make predictions on new data.
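Finding the checkpoint prefix (the first step) can also be done programmatically. TF 1.x provides `tf.train.latest_checkpoint(checkpoint_dir)`, which reads the `checkpoint` state file in the folder; the stdlib-only sketch below instead scans the filenames directly. It assumes the `ckpt-<step>.index` naming seen in `./tf_checkpoints/ckpt-25800`, and `latest_checkpoint_prefix` is a hypothetical helper, not a TensorFlow function.

```python
import os
import re
import tempfile


def latest_checkpoint_prefix(checkpoint_dir, prefix="ckpt"):
    """Hypothetical helper: return '<dir>/<prefix>-N' for the highest N among
    '<prefix>-N.index' files in checkpoint_dir, or None if there are none."""
    pattern = re.compile(re.escape(prefix) + r"-(\d+)\.index$")
    best_step, best_name = -1, None
    for name in os.listdir(checkpoint_dir):
        match = pattern.match(name)
        # Compare steps as integers, not strings.
        if match and int(match.group(1)) > best_step:
            best_step = int(match.group(1))
            best_name = name[: -len(".index")]  # drop the '.index' suffix
    if best_name is None:
        return None
    return os.path.join(checkpoint_dir, best_name)


# Demo on a throwaway folder with empty placeholder files.
demo_dir = tempfile.mkdtemp()
for fname in ["ckpt-9999.index", "ckpt-25800.index",
              "ckpt-25800.data-00000-of-00001", "checkpoint"]:
    open(os.path.join(demo_dir, fname), "w").close()

print(latest_checkpoint_prefix(demo_dir))  # path ending in 'ckpt-25800'
```

Comparing the step numbers as integers matters: a plain string sort would rank `ckpt-9999` above `ckpt-25800`.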
I'd like to share the TFLearn library, a deep learning library featuring a higher-level API for TensorFlow. With it, you can easily save and restore a model.
Saving a model
```python
import tflearn

model = tflearn.DNN(net)  # 'net' is your designed network model
# Sample training run
model.fit(train_x, train_y, n_epoch=10, validation_set=(test_x, test_y),
          batch_size=10, show_metric=True)
model.save("model_name.ckpt")
```
Restoring a model
```python
model = tflearn.DNN(net)  # rebuild the same network first
model.load("model_name.ckpt")
```
For more examples of TFLearn, you can check sites like...
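In graph mode, the following code saves the model:
```python
saver.save(sess, './my_model.ckpt')
```
The following code restores the model:
```python
import tensorflow as tf

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './my_model.ckpt')
```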
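To save in eager mode:
```python
# The returned path may carry a '-<global step>' suffix, e.g. './my_model.ckpt-1720'
save_path = tf.contrib.eager.save_network_checkpoint(sess, './my_model.ckpt')
```
To restore in eager mode:
```python
tf.contrib.eager.restore_network_checkpoint(sess, save_path)
```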
Here `sess` is an object of class `Network` (not a `tf.Session`). Any `Network` object can be saved and restored this way. A quick explanation of `Network` objects:
```python
import tensorflow as tf
import tensorflow.contrib.eager as tfe


class TwoLayerNetwork(tfe.Network):
    def __init__(self, name):
        super(TwoLayerNetwork, self).__init__(name=name)
        # track_layer registers each layer so its variables show up in self.variables
        self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,)))
        self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,)))

    def call(self, inputs):
        return self.layer_two(self.layer_one(inputs))
```
After constructing an object and calling the `Network`, a list of variables created by tracked `Layer`s is available via `Network.variables`:
```python
sess = TwoLayerNetwork(name="net")  # sess is a Network object, not a tf.Session
output = sess(tf.ones([1, 8]))
print([v.name for v in sess.variables])
```

This example prints the variable names, one kernel and one bias per `tf.layers.Dense` layer:

```
['net/dense/kernel:0',
 'net/dense/bias:0',
 'net/dense_1/kernel:0',
 'net/dense_1/bias:0']
```

These variables can be passed to a `Saver` (`tf.train.Saver`, or `tf.contrib.eager.Saver` when executing eagerly) to save or restore the `Network`, or you can use the network-level helpers:

```python
save_path = tfe.save_network_checkpoint(sess, './my_model.ckpt')  # saving the model
tfe.restore_network_checkpoint(sess, save_path)  # restoring (use the returned path)
```
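Why was `model2.variables` empty in the question? Layers such as `tf.layers.Dense` create their variables on the first call, when the input shape is known, so a freshly constructed `Network` has no variables yet. The TF 1.x docstring of `restore_network_checkpoint` says that checkpoint values with no matching variable are kept for deferred restoration and applied when the variables are created, so calling `model2` once after restoring should populate it. Calling the model first and then restoring (as in the dummy-input step of the eager MNIST answer) assigns the values immediately. The toy, pure-Python sketch below imitates both orders; `LazyDense`, `TinyNetwork` and the list-of-lists kernel are hypothetical stand-ins, not TensorFlow classes.

```python
class LazyDense:
    """Hypothetical toy layer: its kernel is created on the first call,
    once the input width is known (mimicking lazy layer building)."""

    def __init__(self, name, units, pending):
        self.name = name
        self.units = units
        self.pending = pending  # shared dict of deferred checkpoint values
        self.variables = {}

    def __call__(self, inputs):
        key = self.name + "/kernel"
        if key not in self.variables:
            # Use a deferred checkpoint value if one is waiting, else zeros.
            default = [[0.0] * self.units for _ in inputs[0]]
            self.variables[key] = self.pending.pop(key, default)
        kernel = self.variables[key]
        # Plain matrix product: (batch x in) @ (in x units).
        return [[sum(x * kernel[i][j] for i, x in enumerate(row))
                 for j in range(self.units)]
                for row in inputs]


class TinyNetwork:
    """Hypothetical toy network with one lazily built layer."""

    def __init__(self):
        self.pending = {}
        self.layer = LazyDense("net/dense", units=1, pending=self.pending)

    @property
    def variables(self):
        return dict(self.layer.variables)

    def restore(self, checkpoint):
        # Assign variables that already exist; defer the rest until built.
        for key, value in checkpoint.items():
            if key in self.layer.variables:
                self.layer.variables[key] = value
            else:
                self.pending[key] = value

    def __call__(self, inputs):
        return self.layer(inputs)


checkpoint = {"net/dense/kernel": [[2.0], [3.0]]}  # made-up saved values

# Order 1: restore first (deferred), then call to build the variables.
net = TinyNetwork()
net.restore(checkpoint)
print(net.variables)       # {}
print(net([[1.0, 1.0]]))   # [[5.0]]

# Order 2: dummy call first to build, then restore (immediate assignment).
net2 = TinyNetwork()
net2([[0.0, 0.0]])
net2.restore(checkpoint)
print(net2([[1.0, 1.0]]))  # [[5.0]]
```

Both orders end with the checkpointed kernel in place; the only difference is when the values are assigned.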