How to save/restore a model after training?

After you train a model in TensorFlow:

  1. How do you save the trained model?
  2. How do you later restore this saved model?
asked Nov 17 '15 14:11 by mathetes


People also ask

How do I save a model after training in TensorFlow?

You can save the weights of all the layers in a model with the save_weights() method. It stores only the layer weights, not the model architecture, so to save a complete model with TensorFlow (for example as an H5 file) the save() method is recommended instead of save_weights().
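A minimal sketch of the save_weights() round trip, assuming a recent TensorFlow 2.x with `tf.keras`; the tiny model, the random data, and the file name `toy.weights.h5` are hypothetical. Because only the weights are stored, the same architecture has to be rebuilt before calling `load_weights()`:

```python
import numpy as np
import tensorflow as tf


def build_model():
    # Hypothetical toy architecture, used only for illustration
    model = tf.keras.Sequential([
        tf.keras.Input(shape=(4,)),
        tf.keras.layers.Dense(8, activation="relu"),
        tf.keras.layers.Dense(1),
    ])
    model.compile(optimizer="adam", loss="mse")
    return model


# Hypothetical random training data
rng = np.random.default_rng(0)
x = rng.random((32, 4), dtype=np.float32)
y = rng.random((32, 1), dtype=np.float32)

model = build_model()
model.fit(x, y, epochs=2, verbose=0)

# save_weights() stores only the layer weights, not the architecture
model.save_weights("toy.weights.h5")

# To restore, rebuild the same architecture first, then load the weights into it
restored = build_model()
restored.load_weights("toy.weights.h5")

original_pred = model.predict(x, verbose=0)
restored_pred = restored.predict(x, verbose=0)
print(np.allclose(original_pred, restored_pred))
```

The `.weights.h5` suffix is chosen because Keras 3 requires it for save_weights(); older tf.keras versions treat any `.h5` name as HDF5, so the same name works there too.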

How do you save a Keras model after training?

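There are two formats you can use to save an entire model to disk: the TensorFlow SavedModel format and the older Keras H5 format. The recommended format is SavedModel, which is the default when you call model.save(). In Keras 3 (the default Keras from TensorFlow 2.16), model.save() instead expects a `.keras` (or `.h5`) filename.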
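A minimal sketch of saving and reloading an entire model, assuming a TensorFlow 2.x release that supports the native `.keras` format; the model, data, and the file name `toy_model.keras` are hypothetical. Unlike the weights-only approach, `load_model()` needs no model-building code:

```python
import numpy as np
import tensorflow as tf

# Hypothetical random training data
rng = np.random.default_rng(0)
x = rng.random((32, 4), dtype=np.float32)
y = rng.random((32, 1), dtype=np.float32)

# Hypothetical toy model
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])
model.compile(optimizer="adam", loss="mse")
model.fit(x, y, epochs=2, verbose=0)

# save() stores the architecture, the weights, and the compile/optimizer state in one file
model.save("toy_model.keras")

# The architecture is read back from the file, so no build code is needed here
loaded = tf.keras.models.load_model("toy_model.keras")

original_pred = model.predict(x, verbose=0)
loaded_pred = loaded.predict(x, verbose=0)
print(np.allclose(original_pred, loaded_pred))
```

A `.keras` filename is used rather than a bare directory path because Keras 3 rejects paths without a recognized extension; on older TensorFlow 2.x releases, a path without an extension produces a SavedModel directory instead.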


2 Answers

I am improving my answer to add more details for saving and restoring models.

In (and after) TensorFlow version 0.11:

Save the model:

```python
import tensorflow as tf
# (Under TensorFlow 2.x, use `import tensorflow.compat.v1 as tf` followed by
#  `tf.disable_v2_behavior()` to run this graph-mode code.)

# Prepare to feed input, i.e. feed_dict and placeholders
w1 = tf.placeholder("float", name="w1")
w2 = tf.placeholder("float", name="w2")
b1 = tf.Variable(2.0, name="bias")
feed_dict = {w1: 4, w2: 8}

# Define a test operation that we will restore
w3 = tf.add(w1, w2)
w4 = tf.multiply(w3, b1, name="op_to_restore")
sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Create a saver object which will save all the variables
saver = tf.train.Saver()

# Run the operation by feeding input
print(sess.run(w4, feed_dict))  # Prints 24.0, i.e. (w1 + w2) * b1 = (4 + 8) * 2

# Now, save the graph and the variable values as 'my_test_model-1000'
saver.save(sess, 'my_test_model', global_step=1000)
```
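`saver.save()` writes several files side by side. The sketch below, assuming TensorFlow 2.x with the graph-mode API reached through `tf.compat.v1` and a temporary directory in place of the working directory, prints the checkpoint prefix that `save()` returns and the files it produced. With the default V2 checkpoint format you should see a `checkpoint` file plus `.meta`, `.index`, and `.data-00000-of-00001` files for `my_test_model-1000`:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TensorFlow 2.x, running the TF1-style graph API via tf.compat.v1
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()

with tf.Graph().as_default():
    b1 = tf.Variable(2.0, name="bias")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint prefix it wrote, ending in "-1000" here
        save_path = saver.save(
            sess, os.path.join(ckpt_dir, "my_test_model"), global_step=1000
        )

files = sorted(os.listdir(ckpt_dir))
print(save_path)
print(files)
```

The `.meta` file holds the graph definition, while the `.index` and `.data` files hold the variable values; `checkpoint` records the latest prefix, which is what `tf.train.latest_checkpoint()` reads.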

Restore the model:

```python
import tensorflow as tf

sess = tf.Session()

# First, load the meta graph and restore the weights
saver = tf.train.import_meta_graph('my_test_model-1000.meta')
saver.restore(sess, tf.train.latest_checkpoint('./'))

# Access saved Variables directly
print(sess.run('bias:0'))  # Prints 2.0, the value of bias that we saved

# Now, access the placeholders and create a feed_dict to feed new data
graph = tf.get_default_graph()
w1 = graph.get_tensor_by_name("w1:0")
w2 = graph.get_tensor_by_name("w2:0")
feed_dict = {w1: 13.0, w2: 17.0}

# Now, access the op that you want to run
op_to_restore = graph.get_tensor_by_name("op_to_restore:0")

print(sess.run(op_to_restore, feed_dict))  # Prints 60.0, i.e. (13 + 17) * 2
```
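To try both halves in a single process under TensorFlow 2.x, here is an adaptation of the code above (not the answer's original). It assumes the graph-mode API via `tf.compat.v1`, builds the save and restore steps in separate `tf.Graph` objects so the tensor names do not collide, and writes to a temporary directory:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TensorFlow 2.x, running the TF1-style graph API via tf.compat.v1
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()
prefix = os.path.join(ckpt_dir, "my_test_model")

# Save: build and run the graph, then checkpoint it
with tf.Graph().as_default():
    w1 = tf.placeholder(tf.float32, name="w1")
    w2 = tf.placeholder(tf.float32, name="w2")
    b1 = tf.Variable(2.0, name="bias")
    w4 = tf.multiply(tf.add(w1, w2), b1, name="op_to_restore")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved_result = sess.run(w4, {w1: 4.0, w2: 8.0})
        saver.save(sess, prefix, global_step=1000)

# Restore: a fresh graph is rebuilt entirely from the .meta file
with tf.Graph().as_default():
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph(prefix + "-1000.meta")
        saver.restore(sess, tf.train.latest_checkpoint(ckpt_dir))
        graph = tf.get_default_graph()
        w1 = graph.get_tensor_by_name("w1:0")
        w2 = graph.get_tensor_by_name("w2:0")
        op = graph.get_tensor_by_name("op_to_restore:0")
        restored_bias = sess.run("bias:0")
        restored_result = sess.run(op, {w1: 13.0, w2: 17.0})

print(saved_result, restored_bias, restored_result)
```

The restored graph computes 60.0 only because the saved bias (2.0) was loaded; no Python code for the operation is rewritten in the restore step.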

This and some more advanced use cases are explained well in the following tutorial:

A quick complete tutorial to save and restore Tensorflow models

answered Sep 22 '22 22:09 by sankit


According to https://www.tensorflow.org/programmers_guide/meta_graph, in (and after) TensorFlow version 0.11.0RC1 you can save and restore your model directly by calling `tf.train.export_meta_graph` and `tf.train.import_meta_graph`.
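On its own, `tf.train.export_meta_graph` stores the graph structure (ops, names, collections) but not the variable values. A minimal sketch, assuming TensorFlow 2.x via `tf.compat.v1`; the names `a`, `scale`, and `scaled` are hypothetical. After `import_meta_graph` the variable exists again but must be initialized (or restored with a `Saver`), so here it takes its initial value 3.0 rather than any trained value:

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

# Assumption: TensorFlow 2.x, running the TF1-style graph API via tf.compat.v1
tf.disable_v2_behavior()

meta_path = os.path.join(tempfile.mkdtemp(), "graph_only.meta")

# Export only the graph definition; no checkpoint (variable values) is written
with tf.Graph().as_default():
    a = tf.placeholder(tf.float32, name="a")
    scale = tf.Variable(3.0, name="scale")
    tf.multiply(a, scale, name="scaled")
    tf.train.export_meta_graph(filename=meta_path)

# Import the structure into a fresh graph and look tensors up by name
with tf.Graph().as_default():
    tf.train.import_meta_graph(meta_path)
    graph = tf.get_default_graph()
    a = graph.get_tensor_by_name("a:0")
    scaled = graph.get_tensor_by_name("scaled:0")
    with tf.Session() as sess:
        # No values were saved, so the variable gets its initializer value again
        sess.run(tf.global_variables_initializer())
        result = sess.run(scaled, {a: 2.0})

print(result)
```

This is why the answer pairs `import_meta_graph` with `Saver.restore`: the meta graph brings back the structure, and the checkpoint brings back the trained values.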

Save the model

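```python
import tensorflow as tf
# (Under TensorFlow 2.x, use `import tensorflow.compat.v1 as tf` followed by
#  `tf.disable_v2_behavior()` to run this graph-mode code.)

w1 = tf.Variable(tf.truncated_normal(shape=[10]), name='w1')
w2 = tf.Variable(tf.truncated_normal(shape=[20]), name='w2')
# Put the variables in a named collection so they are easy to find after restoring
tf.add_to_collection('vars', w1)
tf.add_to_collection('vars', w2)
saver = tf.train.Saver()
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver.save(sess, 'my-model')
# The `save` method calls `export_meta_graph` implicitly,
# so you also get the saved graph file my-model.meta
```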

Restore the model

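```python
import tensorflow as tf

sess = tf.Session()
# Rebuild the graph from the .meta file, then load the saved variable values
new_saver = tf.train.import_meta_graph('my-model.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
all_vars = tf.get_collection('vars')
for v in all_vars:
    v_ = sess.run(v)
    print(v_)
```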
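If you would rather not maintain a custom collection, `tf.global_variables()` also returns the restored variables after `import_meta_graph`. The sketch below is a variation on the answer, not its original code. It assumes TensorFlow 2.x via `tf.compat.v1` and a temporary directory, saves two random variables, restores them into a fresh graph, and checks that the values match:

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

# Assumption: TensorFlow 2.x, running the TF1-style graph API via tf.compat.v1
tf.disable_v2_behavior()

prefix = os.path.join(tempfile.mkdtemp(), "my-model")

# Save two randomly initialized variables and remember their values
with tf.Graph().as_default():
    w1 = tf.Variable(tf.truncated_normal(shape=[10]), name="w1")
    w2 = tf.Variable(tf.truncated_normal(shape=[20]), name="w2")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saved = {"w1": sess.run(w1), "w2": sess.run(w2)}
        saver.save(sess, prefix)

# Restore into a fresh graph and collect every global variable by op name
with tf.Graph().as_default():
    with tf.Session() as sess:
        new_saver = tf.train.import_meta_graph(prefix + ".meta")
        new_saver.restore(sess, prefix)
        restored = {v.op.name: sess.run(v) for v in tf.global_variables()}

for name, value in saved.items():
    print(name, np.array_equal(value, restored[name]))
```

Matching values confirm that the restored variables hold the checkpointed numbers rather than fresh random initializations.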
answered Sep 19 '22 22:09 by lei du