How to restore a session in TensorFlow? [duplicate]

I want to use my neural network without training it again. I read about saving the model like this:

save_path = saver.save(sess, "model.ckpt")
print("Model saved in file: %s" % save_path)

Now I have three files in the folder: checkpoint, model.ckpt, and model.ckpt.meta.

In another Python class, I want to restore the data, get the weights of my neural network, and make a single prediction.

How can I do this?

Asked by Or Perets on Dec 08 '16 at 10:12



1 Answer

To save the model, you can do this:

import tensorflow as tf

model_checkpoint = 'model.chkpt'

# Create the model
...
...

with tf.Session() as sess:
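    # global_variables_initializer / global_variables replace the deprecated
    #   initialize_all_variables / all_variables
    sess.run(tf.global_variables_initializer())

    # Create a saver so we can save and load the model as we train it
    tf_saver = tf.train.Saver(tf.global_variables())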

    # (Optionally) do some training of the model
    ...
    ...

    tf_saver.save(sess, model_checkpoint)
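For reference, here is a small, self-contained version of the save step that actually runs. It is only a sketch under assumptions: it uses the TF1 graph/session API through tf.compat.v1 (so it also runs on TensorFlow 2.x), the two-input linear model and the tensor names x, W, b and y_pred are hypothetical, "training" is replaced by assigning fixed values, and the checkpoint goes to a temporary directory. Naming the input and output tensors is optional, but it makes them easy to look up after restoring.

```python
import os
import tempfile

import tensorflow as tf

# TF1-style graph/session API; under TensorFlow 2.x it lives in tf.compat.v1.
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_model():
    """Hypothetical tiny linear model: y_pred = x @ W + b."""
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    W = tf1.get_variable("W", shape=[2, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    y_pred = tf.add(tf.matmul(x, W), b, name="y_pred")
    return x, y_pred, W, b


def train_and_save(checkpoint_prefix):
    graph = tf1.Graph()
    with graph.as_default():
        _, _, W, b = build_model()
        saver = tf1.train.Saver()  # created after the variables it should save
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            # Stand-in for real training: assign fixed "learned" values.
            sess.run([tf1.assign(W, [[2.0], [-1.0]]), tf1.assign(b, [0.5])])
            # save() returns the checkpoint prefix it wrote to.
            return saver.save(sess, checkpoint_prefix)


ckpt_dir = tempfile.mkdtemp()
save_path = train_and_save(os.path.join(ckpt_dir, "model.ckpt"))
print("Model saved in file: %s" % save_path)
print(sorted(os.listdir(ckpt_dir)))
```

Newer TensorFlow versions write model.ckpt.index and model.ckpt.data-00000-of-00001 instead of a single model.ckpt file, next to checkpoint and model.ckpt.meta. In both cases you pass the prefix model.ckpt to save() and restore().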

I assume you have already done this, since you have the three files. To load the model in another class, you can do this:

import tensorflow as tf

# The same file as we saved earlier
model_checkpoint = 'model.chkpt'

# Create the SAME model as before
...
...

with tf.Session() as sess:
    # Restore the model
    tf_saver = tf.train.Saver()
    tf_saver.restore(sess, model_checkpoint)

    # Now your model is loaded with the same values as when you saved,
    #   and you can do prediction or continue training
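The restore step can be filled in the same way. This sketch makes the same assumptions (tf.compat.v1, a hypothetical linear model with tensors x, W, b and y_pred, a temporary directory). It first writes a checkpoint only so that it runs on its own, then rebuilds the same model in a fresh graph, restores it, reads the weights, and makes a single prediction by feeding a batch containing one sample.

```python
import os
import tempfile

import tensorflow as tf

# TF1-style graph/session API (tf.compat.v1 under TensorFlow 2.x).
tf1 = tf.compat.v1
tf1.disable_eager_execution()


def build_model():
    """Hypothetical linear model; must match the model that was saved."""
    x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
    W = tf1.get_variable("W", shape=[2, 1], initializer=tf1.zeros_initializer())
    b = tf1.get_variable("b", shape=[1], initializer=tf1.zeros_initializer())
    y_pred = tf.add(tf.matmul(x, W), b, name="y_pred")
    return x, y_pred, W, b


def make_checkpoint(checkpoint_prefix):
    """Only so this sketch is self-contained: saves W=[[2], [-1]], b=[0.5]."""
    graph = tf1.Graph()
    with graph.as_default():
        _, _, W, b = build_model()
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            sess.run([tf1.assign(W, [[2.0], [-1.0]]), tf1.assign(b, [0.5])])
            return saver.save(sess, checkpoint_prefix)


def restore_and_predict(checkpoint_prefix, sample):
    graph = tf1.Graph()
    with graph.as_default():
        x, y_pred, W, _ = build_model()  # the SAME model as when saving
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            # No initializer: restore() assigns every variable stored in the checkpoint.
            saver.restore(sess, checkpoint_prefix)
            weights = sess.run(W)
            # A single prediction is just a batch of size one.
            output = sess.run(y_pred, feed_dict={x: [sample]})
    return weights, float(output[0, 0])


ckpt_dir = tempfile.mkdtemp()
prefix = make_checkpoint(os.path.join(ckpt_dir, "model.ckpt"))
weights, prediction = restore_and_predict(prefix, [1.0, 3.0])
print(prediction)  # -0.5
```

The Saver matches variables by name, so the rebuilt model has to create the same variables with the same names; a variable that is missing from the checkpoint makes restore() fail.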
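If you would rather not repeat the model-building code in the other class, the model.ckpt.meta file offers an alternative (not the approach of the answer above): tf.compat.v1.train.import_meta_graph recreates the saved graph and returns a Saver for it, and you then look up the input and output tensors by name. In this sketch the names x:0 and y_pred:0 are assumptions that only work because the hypothetical model named its tensors that way when it was saved.

```python
import os
import tempfile

import tensorflow as tf

tf1 = tf.compat.v1
tf1.disable_eager_execution()


def make_checkpoint(ckpt_dir):
    """Only so this sketch is self-contained: saves a tiny linear model whose
    input and output tensors are named "x" and "y_pred" (hypothetical names)."""
    graph = tf1.Graph()
    with graph.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 2], name="x")
        W = tf1.Variable([[2.0], [-1.0]], name="W")
        b = tf1.Variable([0.5], name="b")
        tf.add(tf.matmul(x, W), b, name="y_pred")
        saver = tf1.train.Saver()
        with tf1.Session(graph=graph) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, os.path.join(ckpt_dir, "model.ckpt"))


def predict_from_meta(ckpt_dir, sample):
    """Restores graph AND weights without rebuilding the model in Python."""
    prefix = tf1.train.latest_checkpoint(ckpt_dir)  # reads the "checkpoint" file
    graph = tf1.Graph()
    with graph.as_default():
        saver = tf1.train.import_meta_graph(prefix + ".meta")
        with tf1.Session(graph=graph) as sess:
            saver.restore(sess, prefix)
            x = graph.get_tensor_by_name("x:0")
            y_pred = graph.get_tensor_by_name("y_pred:0")
            output = sess.run(y_pred, feed_dict={x: [sample]})
    return prefix, float(output[0, 0])


ckpt_dir = tempfile.mkdtemp()
make_checkpoint(ckpt_dir)
latest, prediction = predict_from_meta(ckpt_dir, [1.0, 3.0])
print(prediction)  # -0.5
```

tf.compat.v1.train.latest_checkpoint reads the checkpoint file in the directory and returns the most recent prefix, so the other class only needs to know the directory.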
Answered by Ole Steinar Skrede on Oct 07 '22 at 01:10
