I know there are countless questions on Stack Overflow, GitHub, etc. about how to restore a trained model in TensorFlow. I have read most of them (1, 2, 3).
I have almost exactly the same problem as 3, but I would like to solve it differently if possible. My training and test code need to be in separate scripts called from the shell, and I do not want to repeat in the test script the exact lines I used to define the graph. That rules out tensorflow FLAGS and the other answers that rely on rebuilding the graph by hand.
I also do not want to sess.run every variable and map them manually (using import_graph_def with the input_map argument), as was explained elsewhere, because my graph is quite big.
So I build a graph and train it in one script, for instance (without the training part):
#Script 1
import tensorflow as tf
import cPickle as pickle
x=tf.Variable(42)
saver=tf.train.Saver()
sess=tf.Session()
#Saving the graph
graph_def=sess.graph_def
with open('graph.pkl','wb') as output:
    pickle.dump(graph_def,output,pickle.HIGHEST_PROTOCOL)
#Training the model
sess.run(tf.initialize_all_variables())
#Saving the variables
saver.save(sess,"pretrained_model.ckpt")
I now have both graph and variables saved so I should be able to run my test model from another script even if I have extra training nodes in my graph.
#Script 2
import tensorflow as tf
import cPickle as pickle
sess=tf.Session()
with open('graph.pkl','rb') as input:
    graph_def=pickle.load(input)
tf.import_graph_def(graph_def,name='persisted')
Then obviously I want to restore the variables using a saver, but I hit the same problem as 3: no variables are found to save, so I cannot even create a saver. So I cannot write:
saver=tf.train.Saver()
saver.restore(sess,"pretrained_model.ckpt")
Is there a way to bypass these limitations? I thought that importing the graph would recover the uninitialized variables in every node, but it seems not. Do I really need to build the graph a second time, as most of the answers suggest?
The list of variables is saved in a Collection, which is not saved in the GraphDef. Saver by default uses the list in the ops.GraphKeys.VARIABLES collection (accessible through tf.all_variables()), and if you restored from GraphDef rather than using the Python API to build your model, that collection is empty. You could specify the list of variables manually in tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...]).
Alternatively, instead of GraphDef you could use MetaGraphDef, which saves collections; there's a recently added MetaGraphDef HowTo.
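A minimal sketch of the MetaGraphDef route, under some assumptions: it uses the tensorflow.compat.v1 shim so it can run on a TensorFlow 2.x install (on the older releases used in the question the module would be plain tensorflow and the initializer tf.initialize_all_variables()), the helper names save_model and restore_model are hypothetical, and the variable is named "x" only for illustration. It relies on saver.save() writing a ".meta" file next to the checkpoint by default, which tf.train.import_meta_graph() can load to rebuild the graph, its collections, and a Saver, without repeating the model definition.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # use graph mode, as in the question


def save_model(ckpt_path):
    """'Script 1': build the graph, initialize it, save checkpoint + MetaGraphDef."""
    graph = tf.Graph()
    with graph.as_default():
        tf.Variable(42, name="x")  # stands in for the real model
        saver = tf.train.Saver()
        with tf.Session(graph=graph) as sess:
            sess.run(tf.global_variables_initializer())
            # Also writes ckpt_path + ".meta" (the MetaGraphDef) by default.
            saver.save(sess, ckpt_path)


def restore_model(ckpt_path):
    """'Script 2': rebuild graph and collections from the .meta file, then restore."""
    graph = tf.Graph()
    with graph.as_default():
        # Returns a Saver built from the saver definition stored in the MetaGraphDef.
        saver = tf.train.import_meta_graph(ckpt_path + ".meta")
        with tf.Session(graph=graph) as sess:
            saver.restore(sess, ckpt_path)
            # The variables collection is populated again after the import.
            return {v.op.name: int(sess.run(v)) for v in tf.global_variables()}


if __name__ == "__main__":
    path = os.path.join(tempfile.mkdtemp(), "pretrained_model.ckpt")
    save_model(path)
    print(restore_model(path))
```

In the two-script setup from the question, save_model would be the body of the training script and restore_model the body of the test script; only the checkpoint path needs to be shared between them.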