TensorFlow freeze_graph script failing on model defined with Keras

Tags:

tensorflow

I'm attempting to export a model built and trained with Keras to a protocol buffer (.pb) file that I can load in a C++ program (as in this example). I've generated a .pb file containing the model definition and a .ckpt file containing the checkpoint data. However, when I try to merge them into a single file using the freeze_graph script, I get the error:

ValueError: Fetch argument 'save/restore_all' of 'save/restore_all' cannot be interpreted as a Tensor. ("The name 'save/restore_all' refers to an Operation not in the graph.")

I'm saving the model like this:

with tf.Session() as sess:
    model = nndetector.architecture.models.vgg19((3, 50, 50))
    model.load_weights('/srv/nn/weights/scratch-vgg19.h5')
    init_op = tf.initialize_all_variables()
    sess.run(init_op)
    graph_def = sess.graph.as_graph_def()
    tf.train.write_graph(graph_def=graph_def, logdir='.', name='model.pb', as_text=False)
    saver = tf.train.Saver()
    saver.save(sess, 'model.ckpt')

nndetector.architecture.models.vgg19((3, 50, 50)) is simply a VGG19-like model defined in Keras.

I'm calling the freeze_graph script like this:

bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=[path-to-model.pb] --input_checkpoint=[path-to-model.ckpt] --output_graph=[output-path] --output_node_names=sigmoid --input_binary=True

If I run the freeze_graph_test script, everything works fine.

Does anyone know what I'm doing wrong?

Thanks.

Best regards

Philip

EDIT

I've tried printing tf.train.Saver().as_saver_def().restore_op_name, which returns save/restore_all.

Additionally, I've tried a simple, pure TensorFlow example and still get the same error:

import tensorflow as tf

a = tf.Variable(tf.constant(1), name='a')
b = tf.Variable(tf.constant(2), name='b')
add = tf.add(a, b, 'sum')

with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.', name='simple_as_binary.pb', as_text=False)
    tf.train.Saver().save(sess, 'simple.ckpt')
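A self-contained sketch that reproduces the failing step without the freeze_graph binary might look like this. It assumes the old graph/session API as exposed under tf.compat.v1 (available in TF 1.13+ and TF 2.x), writes to a temporary directory instead of '.', and then does by hand the step the error message says fails: load the .pb into a fresh graph and fetch save/restore_all by name.

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1  # assumption: TF 1.13+ or TF 2.x graph/session API
workdir = tempfile.mkdtemp()  # stands in for logdir='.'

# Same toy graph as the example, in the original order.
graph = tf.Graph()
with graph.as_default():
    a = v1.Variable(1, name="a")
    b = v1.Variable(2, name="b")
    tf.add(a, b, name="sum")
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        # The graph is serialized BEFORE any Saver exists...
        pb_path = tf.io.write_graph(graph.as_graph_def(), workdir,
                                    "simple_as_binary.pb", as_text=False)
        # ...so the save/* ops created here never reach the .pb file.
        v1.train.Saver().save(sess, os.path.join(workdir, "simple.ckpt"))

# Load the .pb into a fresh graph and try to fetch the restore op by name.
graph_def = v1.GraphDef()
with open(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())

error_message = None
with tf.Graph().as_default() as fresh:
    tf.import_graph_def(graph_def, name="")
    with v1.Session(graph=fresh) as sess:
        try:
            sess.run("save/restore_all")
        except ValueError as err:
            error_message = str(err)

print(error_message)
```

The printed message should match the "Fetch argument 'save/restore_all' ... cannot be interpreted as a Tensor" error quoted at the top, since the fetch is resolved against a graph that simply has no node with that name.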

I'm also unable to restore the graph in Python. The following code throws ValueError: No variables to save when I run it in a separate script from the one that saved the graph (if I both save and restore the model in the same script, everything works fine).

import tensorflow as tf
from tensorflow.python.platform import gfile

with gfile.FastGFile('simple_as_binary.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

with tf.Session() as sess:
    tf.import_graph_def(graph_def)
    saver = tf.train.Saver()
    saver.restore(sess, 'simple.ckpt')

I'm not sure if the two problems are related, or if I'm simply not restoring the model correctly in python.
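The Python-side ValueError: No variables to save is likely a separate issue: tf.import_graph_def recreates operations only, not the variable collections that tf.train.Saver() consults when it is called without a var_list. One way around it, sketched below under the same tf.compat.v1 assumption, is to restore from the MetaGraph file (simple.ckpt.meta) that Saver.save writes next to the checkpoint by default. tf.train.import_meta_graph rebuilds the graph together with its collections and returns a Saver for it. The two "scripts" are simulated here with two separate tf.Graph objects in one process.

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1  # assumption: TF 1.13+ or TF 2.x graph/session API
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "simple.ckpt")

# "Script 1": build, initialize and save. By default Saver.save also writes
# simple.ckpt.meta, a MetaGraphDef that includes the variable collections.
with tf.Graph().as_default():
    a = v1.Variable(1, name="a")
    b = v1.Variable(2, name="b")
    tf.add(a, b, name="sum")
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        v1.train.Saver().save(sess, ckpt_prefix)

# "Script 2": a fresh graph with no Python Variable objects of its own.
with tf.Graph().as_default() as graph:
    # Rebuilds the graph plus collections and returns a matching Saver.
    saver = v1.train.import_meta_graph(ckpt_prefix + ".meta")
    with v1.Session(graph=graph) as sess:
        saver.restore(sess, ckpt_prefix)
        restored_sum = sess.run("sum:0")

print(restored_sum)  # 3
```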

asked Jun 08 '16 13:06 by ppries

1 Answer

The problem is the order of these two lines in your original program:

tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.', name='simple_as_binary.pb', as_text=False)
tf.train.Saver().save(sess, 'simple.ckpt')

Calling tf.train.Saver() adds a set of nodes to the graph, including one called "save/restore_all". However, this program calls it after writing out the graph, so the file you pass to freeze_graph.py doesn't contain those nodes, which are necessary for doing the rewriting.
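This can be checked directly by comparing the graph's operation names before and after constructing the Saver. The sketch assumes the tf.compat.v1 API and reuses the toy a + b graph from the question.

```python
import tensorflow as tf

v1 = tf.compat.v1  # assumption: TF 1.13+ or TF 2.x graph/session API

with tf.Graph().as_default() as graph:
    a = v1.Variable(1, name="a")
    b = v1.Variable(2, name="b")
    tf.add(a, b, name="sum")

    # Snapshot the operation names before and after building the Saver.
    before = {op.name for op in graph.get_operations()}
    saver = v1.train.Saver()
    after = {op.name for op in graph.get_operations()}

restore_op = saver.as_saver_def().restore_op_name
print(restore_op)            # save/restore_all
print(restore_op in before)  # False: not in the graph before the Saver exists
print(restore_op in after)   # True: added by constructing the Saver
```

This is also consistent with the restore_op_name check in the question: the name exists in the in-memory graph once a Saver has been built, even though a write_graph call made earlier never saw it.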

Reversing the two lines should make the script work as intended:

tf.train.Saver().save(sess, 'simple.ckpt')
tf.train.write_graph(graph_def=sess.graph.as_graph_def(), logdir='.', name='simple_as_binary.pb', as_text=False)
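A fuller sketch of the corrected order, assuming tf.compat.v1 and a temporary output directory, checks the written file (rather than the live graph) for the restore op. It then repeats the restore-by-name step in a fresh graph by running the restore op with the checkpoint path fed into the Saver's filename tensor, which is roughly what freeze_graph's --restore_op_name and --filename_tensor_name options refer to. Both names are read from saver.as_saver_def() instead of being hard-coded.

```python
import os
import tempfile

import tensorflow as tf

v1 = tf.compat.v1  # assumption: TF 1.13+ or TF 2.x graph/session API
workdir = tempfile.mkdtemp()  # stands in for logdir='.'

with tf.Graph().as_default() as graph:
    a = v1.Variable(1, name="a")
    b = v1.Variable(2, name="b")
    tf.add(a, b, name="sum")
    with v1.Session() as sess:
        sess.run(v1.global_variables_initializer())
        # 1. Create the Saver first, so its save/* ops become part of the graph.
        saver = v1.train.Saver()
        ckpt_path = saver.save(sess, os.path.join(workdir, "simple.ckpt"))
        # 2. Only now serialize the graph, which now includes those ops.
        pb_path = tf.io.write_graph(graph.as_graph_def(), workdir,
                                    "simple_as_binary.pb", as_text=False)
        saver_def = saver.as_saver_def()

# Inspect the file on disk, not the in-memory graph.
graph_def = v1.GraphDef()
with open(pb_path, "rb") as f:
    graph_def.ParseFromString(f.read())
node_names = {node.name for node in graph_def.node}
print(saver_def.restore_op_name in node_names)  # True

# Restore by name in a fresh graph, then evaluate the output node.
with tf.Graph().as_default() as fresh:
    tf.import_graph_def(graph_def, name="")
    with v1.Session(graph=fresh) as sess:
        sess.run(saver_def.restore_op_name,
                 feed_dict={saver_def.filename_tensor_name: ckpt_path})
        result = sess.run("sum:0")

print(result)  # 3
```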
answered Oct 14 '22 02:10 by mrry