I am currently trying to export a trained TensorFlow model as a ProtoBuf file so that I can use it with the TensorFlow C++ API on Android. To do this, I'm using the freeze_graph.py script.

I exported my model using tf.train.write_graph:
tf.train.write_graph(graph_def, FLAGS.save_path, out_name, as_text=True)
and I'm using a checkpoint saved with tf.train.Saver.

I invoke freeze_graph.py as described at the top of the script. After compiling, I run:
bazel-bin/tensorflow/python/tools/freeze_graph \
--input_graph=<path_to_protobuf_file> \
--input_checkpoint=<model_name>.ckpt-10000 \
--output_graph=<output_protobuf_file_path> \
--output_node_names=dropout/mul_1
This gives me the following error message:
TypeError: Cannot interpret feed_dict key as Tensor: The name 'save/Const:0' refers to a Tensor which does not exist. The operation, 'save/Const', does not exist in the graph.
As the error states, I do not have a tensor save/Const:0 in my exported model. However, the code of freeze_graph.py says that one can specify this tensor name via the flag filename_tensor_name. Unfortunately, I cannot find any information on what this tensor should be or how to set it correctly for my model.

Can somebody tell me either how to produce a save/Const:0 tensor in my exported ProtoBuf model or how to set the flag filename_tensor_name correctly?
The --filename_tensor_name flag is used to specify the name of a placeholder tensor created when you construct a tf.train.Saver for your model.*

In your original program, you can print out the value of saver.saver_def.filename_tensor_name to get the value that you should pass for this flag. You may also want to print the value of saver.saver_def.restore_op_name to get a value for the --restore_op_name flag (since I suspect the default won't be correct for your graph).
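For example, a minimal sketch of how to print these values (assuming the TF 1.x graph-mode API; the variable is only there so the Saver has something to save):

import tensorflow as tf

v = tf.Variable(tf.zeros([1]), name='v')      # toy variable, for illustration only
saver = tf.train.Saver()

print(saver.saver_def.filename_tensor_name)   # typically "save/Const:0"
print(saver.saver_def.restore_op_name)        # typically "save/restore_all"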
Alternatively, the tf.train.SaverDef protocol buffer includes all of the information you need to reconstruct the values for these flags. If you prefer, you can write saver.saver_def to a file and pass the name of that file as the --input_saver flag to freeze_graph.py.
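One way to dump the SaverDef to disk looks roughly like this (the file paths are just examples; make sure the text/binary format matches what your version of freeze_graph.py expects):

from google.protobuf import text_format

# Text format:
with open('/tmp/saver_def.pbtxt', 'w') as f:
    f.write(text_format.MessageToString(saver.saver_def))

# Or binary format:
with open('/tmp/saver_def.pb', 'wb') as f:
    f.write(saver.saver_def.SerializeToString())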
* The default name scope for a tf.train.Saver is "save/", and the placeholder is actually a tf.constant() whose name defaults to "Const:0", which explains why the flag defaults to "save/Const:0".
I noticed that this error happened to me when I had the code arranged like this:
sess = tf.Session()
# The graph is written out before the Saver is constructed,
# so the save/Const op never makes it into the saved GraphDef.
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
init = tf.initialize_all_variables()
saver = tf.train.Saver()
sess.run(init)
It worked after I changed the code layout like this:
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
init = tf.initialize_all_variables()
sess = tf.Session()
tf.train.write_graph(sess.graph_def, '', '/tmp/train.pbtxt')
sess.run(init)
I'm not really sure why that is. @mrry, could you explain it a bit more?
Some follow-up on @Drag0's answer and why the new code layout fixed the error.
When calling saver = tf.train.Saver(), you add the ops and variables related to the tf.train.Saver(), such as 'save/Const:0', to the default graph.

In the first code arrangement the graph is saved before the tf.train.Saver() is constructed, so it is written out without the Saver's ops. In the second code arrangement the graph is saved afterwards, so the operation save/Const will exist in the graph.
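A quick way to check this for yourself (a toy sketch, assuming the default Saver names):

import tensorflow as tf

v = tf.Variable(tf.zeros([1]), name='v')            # example variable
saver = tf.train.Saver()

graph = tf.get_default_graph()
print(graph.get_operation_by_name('save/Const'))    # raises KeyError if the op is missing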