
TensorFlow Eager Mode: How to restore a model from a checkpoint?

I've trained a CNN model in TensorFlow eager mode. Now I'm trying to restore the trained model from a checkpoint file, but I haven't had any success.

All the examples I've found (such as the one below) restore a checkpoint into a Session. What I need is to restore the model in eager mode, i.e. without creating a session.

with tf.Session() as sess:
  # Restore variables from disk.
  saver.restore(sess, "/tmp/model.ckpt")

Basically what I need is something like:

tfe.enable_eager_execution()
model = tfe.restore('model.ckpt')
model.predict(...)

and then I can use the model to make predictions.

Can someone please help?

Update

The example code can be found at: mnist eager mode demo

I've tried to follow the steps from @Jay Shah's answer and it almost worked, but the restored model doesn't have any variables in it.

tfe.save_network_checkpoint(model,'./test/my_model.ckpt')

Out[58]:
'./test/my_model.ckpt-1720'

model2 = MNISTModel()
tfe.restore_network_checkpoint(model2,'./test/my_model.ckpt-1720')
model2.variables

Out[72]:
[]

The original model has lots of variables in it:

model.variables

[<tf.Variable 'mnist_model_1/conv2d/kernel:0' shape=(5, 5, 1, 32) dtype=float32, numpy=
 array([[[[ -8.25184360e-02,   6.77833706e-03,   6.97569922e-02,...
Allen asked Dec 17 '17 05:12


4 Answers

Eager execution is still a new feature in TensorFlow and was not included in the latest release, so not all features are supported. Fortunately, loading a model from a saved checkpoint is.

You'll need to use the tfe.Saver class (which is a thin wrapper over the tf.train.Saver class), and your code should look something like this:

saver = tfe.Saver([x, y])
saver.restore('/tmp/ckpt')

Here [x, y] is the list of variables and/or models you wish to restore. It should match exactly the variables passed to the saver that created the checkpoint.
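To see why the variable list has to match, here is a minimal pure-Python sketch of a checkpoint keyed by variable name. It is a toy model, not TensorFlow code: ToySaver, its methods, and the JSON file format are hypothetical names and choices made up for this illustration, and the real Saver may behave differently in its details.

```python
# Toy model of a name-keyed checkpoint. NOT TensorFlow code:
# ToySaver and its methods are hypothetical, for illustration only.
import json
import os
import tempfile


class ToySaver:
    """Saves and restores a fixed set of named values."""

    def __init__(self, variables):
        # variables: dict mapping a name to a one-item list holding the value
        self.variables = variables

    def save(self, path):
        with open(path, "w") as f:
            json.dump({name: box[0] for name, box in self.variables.items()}, f)

    def restore(self, path):
        with open(path) as f:
            stored = json.load(f)
        # Every variable we were asked to restore must exist in the file.
        missing = sorted(set(self.variables) - set(stored))
        if missing:
            raise KeyError("not found in checkpoint: %s" % missing)
        for name, box in self.variables.items():
            box[0] = stored[name]


path = os.path.join(tempfile.mkdtemp(), "toy.ckpt")
ToySaver({"x": [1.5], "y": [2.5]}).save(path)

# Same names as at save time: the values come back.
x, y = [0.0], [0.0]
ToySaver({"x": x, "y": y}).restore(path)

# A name that was never saved: the restore fails and z is left untouched.
z = [0.0]
try:
    ToySaver({"z": z}).restore(path)
    mismatch_failed = False
except KeyError:
    mismatch_failed = True
```

The point of the sketch is only that a restore looks values up by name, so a saver built from a different set of variables has nothing to match against.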

More details, including sample code, can be found here, and the API details of the saver can be found here.

mr_snuffles answered Oct 19 '22 10:10


Ok, after spending a few hours running the code in line-by-line mode, I've figured out a way to restore a checkpoint to a new TensorFlow Eager Mode model.

Using the examples from TF Eager Mode MNIST

Steps:

  1. After your model has been trained, find the index file of the latest checkpoint (or of the checkpoint you want) in the checkpoint folder created during training, such as 'ckpt-25800.index'. When restoring in step 5, use only the prefix 'ckpt-25800', without the '.index' extension.

  2. Start a new Python session and enable TensorFlow eager mode by running:

    import numpy as np  # used for the dummy input in step 4
    import tensorflow.contrib.eager as tfe
    tfe.enable_eager_execution()

  3. Create a new instance of MNISTModel:

    model_new = MNISTModel()

  4. Initialise the variables of model_new by running the model once on dummy input. (This step is important: the variables must exist before the next step can restore them. I couldn't find any other way to initialise variables in eager mode than what I did below.)

    model_new(tfe.Variable(np.zeros((1,784),dtype=np.float32)), training=True)

  5. Restore the variables to model_new using the checkpoint identified in step 1.

    tfe.Saver((model_new.variables)).restore('./tf_checkpoints/ckpt-25800')

  6. If the restore succeeds, you should see something like:

    INFO:tensorflow:Restoring parameters from ./tf_checkpoints/ckpt-25800

Now the checkpoint has been successfully restored to model_new and you can use it to make predictions on new data.
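Step 1 can be scripted. The sketch below uses only the standard library to pick the highest-numbered index file from a directory listing and strip the '.index' extension. The function name latest_checkpoint_prefix and the sample file list are hypothetical, made up for this example. TensorFlow also provides tf.train.latest_checkpoint(checkpoint_dir), which may be the better choice when the training run wrote its usual 'checkpoint' state file.

```python
import os
import re


def latest_checkpoint_prefix(filenames):
    """Return the prefix of the highest-numbered '<name>-<step>.index' file, or None."""
    best_step, best_prefix = -1, None
    for name in filenames:
        match = re.fullmatch(r"(.+-(\d+))\.index", name)
        if match and int(match.group(2)) > best_step:
            best_step, best_prefix = int(match.group(2)), match.group(1)
    return best_prefix


# Hypothetical directory listing; in practice use os.listdir('./tf_checkpoints').
files = [
    "checkpoint",
    "ckpt-1720.index",
    "ckpt-1720.data-00000-of-00001",
    "ckpt-25800.index",
    "ckpt-25800.data-00000-of-00001",
]
prefix = latest_checkpoint_prefix(files)
restore_path = os.path.join("./tf_checkpoints", prefix)
```

The resulting restore_path is what would be passed to restore() in step 5.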

Allen answered Oct 19 '22 10:10


I'd like to share the TFLearn library, a deep learning library featuring a higher-level API for TensorFlow. With it you can easily save and restore a model.

Saving a model

import tflearn

model = tflearn.DNN(net)  # 'net' is the network you designed
# Example training call
model.fit(train_x, train_y, n_epoch=10, validation_set=(test_x, test_y), batch_size=10, show_metric=True)
model.save("model_name.ckpt")

Restore a model

model = tflearn.DNN(net)
model.load("model_name.ckpt")

For more TFLearn examples, see:

  • My first CNN in TFLearn.
  • Github Link
R.A.Munna answered Oct 19 '22 10:10


  • First, save your model to a checkpoint:

saver.save(sess, './my_model.ckpt')

  • The line above saves your session to the "my_model.ckpt" checkpoint.

The following code restores the model:

saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './my_model.ckpt')

  • Restoring the session restores your model from the checkpoint.

To save in eager mode:

tf.contrib.eager.save_network_checkpoint(sess,'./my_model.ckpt')

To restore in eager mode:

tf.contrib.eager.restore_network_checkpoint(sess,'./my_model.ckpt')

Here sess is an object of class Network, not a tf.Session. Any Network object can be saved and restored. A quick explanation of Network objects:

class TwoLayerNetwork(tfe.Network):
    def __init__(self, name):
        super(TwoLayerNetwork, self).__init__(name=name)
        self.layer_one = self.track_layer(tf.layers.Dense(16, input_shape=(8,)))
        self.layer_two = self.track_layer(tf.layers.Dense(1, input_shape=(16,)))
    def call(self, inputs):
        return self.layer_two(self.layer_one(inputs))

After constructing an object and calling the Network, a list of variables created by tracked Layers is available via Network.variables:

  sess = TwoLayerNetwork(name="net")   # sess is a Network object
  output = sess(tf.ones([1, 8]))
  print([v.name for v in sess.variables])

This example prints variable names, one kernel and one bias per tf.layers.Dense layer:

  ['net/dense/kernel:0',
   'net/dense/bias:0',
   'net/dense_1/kernel:0',
   'net/dense_1/bias:0']

These variables can be passed to a Saver (tf.train.Saver, or tf.contrib.eager.Saver when executing eagerly) to save or restore the Network.

  tfe.save_network_checkpoint(sess,'./my_model.ckpt') # saving the model
  tfe.restore_network_checkpoint(sess,'./my_model.ckpt') # restoring
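
The detail that variables only appear after the Network has been called is easy to miss, so here is a toy pure-Python sketch of it. It is not TensorFlow code: LazyDense and LazyModel are hypothetical stand-ins that only mimic the idea that a layer cannot create its kernel until it has seen the width of its input.

```python
# Toy illustration of deferred variable creation. NOT TensorFlow code:
# LazyDense and LazyModel are hypothetical stand-ins.
class LazyDense:
    def __init__(self, units):
        self.units = units
        self.variables = []  # nothing exists until the first call

    def __call__(self, inputs):
        if not self.variables:
            # The kernel shape depends on the input width, so creation waits until now.
            in_dim = len(inputs)
            kernel = [[0.0] * self.units for _ in range(in_dim)]
            bias = [0.0] * self.units
            self.variables = [kernel, bias]
        kernel, bias = self.variables
        return [
            sum(x * kernel[i][j] for i, x in enumerate(inputs)) + bias[j]
            for j in range(self.units)
        ]


class LazyModel:
    def __init__(self):
        self.layers = [LazyDense(16), LazyDense(1)]

    @property
    def variables(self):
        return [v for layer in self.layers for v in layer.variables]

    def __call__(self, inputs):
        for layer in self.layers:
            inputs = layer(inputs)
        return inputs


model = LazyModel()
before = len(model.variables)  # freshly constructed: no variables yet
model([1.0] * 8)               # one dummy forward pass builds every layer
after = len(model.variables)   # one kernel and one bias per layer
```

A freshly constructed model with an empty variable list therefore has not necessarily lost anything; it may simply not have been called yet.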
Jai answered Oct 19 '22 09:10