 

Given a TensorFlow model graph, how do I find the input and output node names?


I use a custom classification model in the TensorFlow Camera Demo. I generated a .pb file (a serialized protobuf file) and could display the huge graph it contains. To convert this graph to an optimized graph, as described in [https://www.oreilly.com/learning/tensorflow-on-android], the following procedure can be used:

$ bazel-bin/tensorflow/python/tools/optimize_for_inference \
    --input=tf_files/retrained_graph.pb \
    --output=tensorflow/examples/android/assets/retrained_graph.pb \
    --input_names=Mul \
    --output_names=final_result

How can I find the input_names and output_names from the graph display? When I don't use the proper names, the app crashes on the device:

E/TensorFlowInferenceInterface(16821): Failed to run TensorFlow inference with inputs:[AvgPool], outputs:[predictions]
E/AndroidRuntime(16821): FATAL EXCEPTION: inference
E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible shapes: [1,224,224,3] vs. [32,1,1,2048]
E/AndroidRuntime(16821):     [[Node: dropout/dropout/mul = Mul[T=DT_FLOAT, _device="/job:localhost/replica:0/task:0/cpu:0"](dropout/dropout/div, dropout/dropout/Floor)]]
Santle Camilus asked Apr 20 '17 11:04





1 Answer

Try this:

Run Python:

>>> import tensorflow as tf
>>> gf = tf.GraphDef()  # TF 1.x; on TF 2.x use tf.compat.v1.GraphDef()
>>> gf.ParseFromString(open('/your/path/to/graphname.pb', 'rb').read())

and then

>>> [n.name + '=>' + n.op for n in gf.node if n.op in ('Softmax', 'Placeholder')]

You should get a result similar to this:

['Mul=>Placeholder', 'final_result=>Softmax'] 

However, judging from the error message, I'm not sure the node names are the problem. I suspect you passed the wrong arguments when loading the graph file, or the generated graph file itself is wrong.

Check this part:

E/AndroidRuntime(16821): java.lang.IllegalArgumentException: Incompatible  shapes: [1,224,224,3] vs. [32,1,1,2048] 
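The two shapes in that message are the operands of the elementwise `Mul` node `dropout/dropout/mul`. Elementwise ops in TensorFlow follow NumPy-style broadcasting: shapes are compared from the last dimension backwards, and each pair of dimensions must be equal or one of them must be 1. The sketch below uses a hypothetical helper, `broadcast_shape`, to apply that rule; it shows that the last dimensions (3 vs. 2048) are what clash, which is consistent with an image-shaped tensor being fed into the middle of the network (the log names `AvgPool` as the input) rather than into the real input node.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the NumPy-style broadcast shape of a and b, or raise ValueError."""
    result = []
    # Walk both shapes from the last dimension; missing leading dims count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x == y or x == 1 or y == 1:
            # If x is 1 the result takes y; otherwise x == y or y == 1, so take x.
            result.append(y if x == 1 else x)
        else:
            raise ValueError(
                f"Incompatible shapes: {list(a)} vs. {list(b)} (dim {x} vs. {y})"
            )
    return list(reversed(result))


if __name__ == "__main__":
    print(broadcast_shape([1, 224, 224, 3], [1, 1, 1, 3]))  # [1, 224, 224, 3]
    try:
        broadcast_shape([1, 224, 224, 3], [32, 1, 1, 2048])
    except ValueError as err:
        print(err)
```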

UPDATE: Sorry, if you're using a (re)trained graph, try this instead:

[n.name + '=>' + n.op for n in gf.node if n.op in ('Softmax', 'Mul')]

It seems that a (re)trained graph stores the input and output nodes with op types "Mul" and "Softmax", while an optimized and/or quantized graph stores them as "Placeholder" and "Softmax".

BTW, using a retrained graph in a mobile environment is not recommended according to Pete Warden's post: https://petewarden.com/2016/09/27/tensorflow-for-mobile-poets/ . It's better to use a quantized or memmapped graph because of performance and file-size issues. I couldn't figure out how to load a memmapped graph on Android, though :( (loading an optimized or quantized graph on Android works fine).

JP Kim answered Sep 22 '22 08:09
