Restoring a graph in TensorFlow fails because there are no variables to save

I know there are countless questions on Stack Overflow, GitHub, etc. about how to restore a trained model in TensorFlow. I have read most of them (1, 2, 3).

I have almost exactly the same problem as 3. However, I would like to solve it differently if possible: my training and test code need to live in separate scripts called from the shell, and I do not want to repeat in the test script the exact lines I used to define the graph. So I cannot use TensorFlow FLAGS or the other answers that rely on rebuilding the graph by hand.

I also do not want to sess.run every variable and map them all by hand, as explained there (using import_graph_def with the input_map argument), because my graph is quite big.

So I build a graph and train it in a dedicated script, for instance (training part omitted):

# Script 1
import tensorflow as tf
import cPickle as pickle

x = tf.Variable(42)
saver = tf.train.Saver()
sess = tf.Session()

# Saving the graph
graph_def = sess.graph_def
with open('graph.pkl', 'wb') as output:
    pickle.dump(graph_def, output, pickle.HIGHEST_PROTOCOL)

# Training the model
sess.run(tf.initialize_all_variables())
# Saving the variables
saver.save(sess, "pretrained_model.ckpt")

I now have both the graph and the variables saved, so I should be able to run my test model from another script, even though my graph contains extra training nodes.

# Script 2
import tensorflow as tf
import cPickle as pickle

sess = tf.Session()
with open('graph.pkl', 'rb') as f:
    graph_def = pickle.load(f)

tf.import_graph_def(graph_def, name='persisted')

Next I want to restore the variables with a saver, but I run into the same problem as 3: no variables are found, so a saver cannot even be created. In other words, I cannot write:

saver = tf.train.Saver()
saver.restore(sess, "pretrained_model.ckpt")
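A single-process sketch that reproduces the failure described above. It assumes TensorFlow 2.x (or 1.15) used through its `tf.compat.v1` graph-mode API rather than the 2016-era `tf` namespace; the variable name `x` and calling `as_graph_def()` directly (instead of the pickle round-trip) are illustrative choices. After `import_graph_def`, the variable op is present in the new graph, yet `tf.global_variables()` is empty and `tf.train.Saver()` raises a `ValueError`.

```python
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode and reference variables

# Equivalent of "Script 1": build a graph with one variable, keep its GraphDef.
with tf.Graph().as_default() as train_graph:
    x = tf.Variable(42, name="x")
    graph_def = train_graph.as_graph_def()

# Equivalent of "Script 2": import the GraphDef into a fresh graph.
with tf.Graph().as_default() as test_graph:
    tf.import_graph_def(graph_def, name="persisted")

    # The variable op itself was imported ...
    has_variable_op = any(
        op.name == "persisted/x" for op in test_graph.get_operations()
    )
    # ... but the Python-side collection that Saver reads by default is empty.
    num_collected = len(tf.global_variables())

    # With nothing in the collection and no var_list, Saver refuses to build.
    try:
        tf.train.Saver()
        saver_error = None
    except ValueError as e:
        saver_error = str(e)

print(has_variable_op, num_collected, saver_error)
```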

Is there a way to bypass these limitations? I thought that importing the graph would recover the uninitialized variables in every node, but it seems not. Do I really need to rebuild the graph a second time, as most of the answers suggest?

asked May 04 '16 09:05 by jeandut


1 Answer

The list of variables is stored in a collection, and collections are not saved in the GraphDef. By default, Saver uses the list in the ops.GraphKeys.VARIABLES collection (accessible through tf.all_variables()); if you restored your model from a GraphDef rather than building it with the Python API, that collection is empty. You could specify the list of variables manually in tf.train.Saver(var_list=['MyVariable1:0', 'MyVariable2:0',...]).
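A minimal sketch of the manual `var_list` route, under the same `tf.compat.v1` assumption. `Saver`'s documentation describes `var_list` as a list of `Variable` objects or a dict mapping checkpoint names to variables, so this sketch passes a dict: each imported variable tensor (e.g. `persisted/x:0`) is mapped back to the name it was saved under (`x`), since the `persisted/` prefix exists only in the importing graph. That the saver accepts these raw variable tensors in place of `tf.Variable` objects is an assumption based on its graph-mode behavior, not something the answer states. Selecting ops by type `VariableV2` (the op type TF 1.x-style variables use) avoids listing a big graph's variables by hand; the temporary directory and the `assign(7)` stand-in for training are hypothetical.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode and reference variables

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "pretrained_model.ckpt")

# "Script 1": build the graph, keep its GraphDef, train, save the variables.
with tf.Graph().as_default() as train_graph:
    x = tf.Variable(42, name="x")
    graph_def = train_graph.as_graph_def()  # what the question pickles
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(x.assign(7))  # hypothetical stand-in for training
        saver.save(sess, ckpt_prefix)

# "Script 2": import the GraphDef, then hand Saver the variables explicitly.
with tf.Graph().as_default() as test_graph:
    prefix = "persisted/"
    tf.import_graph_def(graph_def, name=prefix.rstrip("/"))

    # Checkpoint name ("x") -> imported variable tensor ("persisted/x:0").
    var_list = {
        op.name[len(prefix):]: op.outputs[0]
        for op in test_graph.get_operations()
        if op.type == "VariableV2"
    }
    restorer = tf.train.Saver(var_list=var_list)

    with tf.Session() as sess:
        restorer.restore(sess, ckpt_prefix)  # no initializer needed
        restored = sess.run(var_list["x"])

print(sorted(var_list), restored)
```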

Alternatively, instead of a GraphDef you could use a MetaGraphDef, which also saves collections; there's a recently added MetaGraphDef HowTo.
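A sketch of the MetaGraphDef route, again assuming TensorFlow used through `tf.compat.v1` in graph mode. By default `Saver.save()` writes a `.meta` file (a serialized MetaGraphDef) next to the checkpoint, and `tf.train.import_meta_graph()` rebuilds the graph, restores its collections (so `tf.global_variables()` is populated again), and returns a `Saver` ready for `restore()`. The second half contains no model-building code, which is what the question asks for; the temporary directory and the `assign(7)` stand-in for training are illustrative.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()  # TF 1.x-style graph mode and reference variables

ckpt_prefix = os.path.join(tempfile.mkdtemp(), "pretrained_model.ckpt")

# "Script 1": build, train, save. save() also writes
# pretrained_model.ckpt.meta, which includes the variable collections.
with tf.Graph().as_default():
    x = tf.Variable(42, name="x")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(x.assign(7))  # hypothetical stand-in for training
        saver.save(sess, ckpt_prefix)

# "Script 2": graph, collections and a matching Saver all come from the .meta file.
with tf.Graph().as_default() as test_graph:
    restorer = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    num_collected = len(tf.global_variables())  # collection is back
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_prefix)
        restored = sess.run(test_graph.get_tensor_by_name("x:0"))

print(num_collected, restored)
```

Tensors needed at test time can then be fetched by name with `get_tensor_by_name()`; the imported names carry no prefix unless `import_scope` is passed to `import_meta_graph()`.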

answered Oct 21 '22 19:10 by Yaroslav Bulatov