
Tensorflow print all placeholder variable names from meta graph

Tags:

tensorflow

I have a tensorflow model for which I have the .meta and the checkpoint files. I am trying to print all the placeholders that the model requires, without looking at the code that constructed the model, so that I can construct an input feed_dict without knowing how the model was created. For reference, here is the model construction code (in another file):

def save():
    import tensorflow as tf
    v1 = tf.placeholder(tf.float32, name="v1") 
    v2 = tf.placeholder(tf.float32, name="v2")
    v3 = tf.multiply(v1, v2)
    vx = tf.Variable(10.0, name="vx")
    v4 = tf.add(v3, vx, name="v4")
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.global_variables_initializer())  # initialize_all_variables() is deprecated
    sess.run(vx.assign(tf.add(vx, vx)))
    result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
    print(result)
    saver.save(sess, "./model_ex1")

Now, in another file, I have the following code to restore it:

def restore():
    import tensorflow as tf
    saver = tf.train.import_meta_graph("./model_ex1.meta")
    print(tf.get_default_graph().get_all_collection_keys())
    for v in tf.get_default_graph().get_collection("variables"):
        print(v)
    for v in tf.get_default_graph().get_collection("trainable_variables"):
        print(v)
    sess = tf.Session()
    saver.restore(sess, "./model_ex1")
    result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 4.0})
    print(result)

However, when I print all the variables as above, I do not see "v1:0" or "v2:0" as variable names anywhere. How can I identify the tensor names of the placeholders without looking at the code that created the model?

asked May 29 '17 by shyamupa

People also ask

What is the difference between placeholder and variable in TensorFlow?

Think of a Variable in TensorFlow as a normal variable in a programming language: we give it an initial value and can modify it later. A placeholder, by contrast, doesn't require an initial value; it simply reserves a slot in the graph to be filled with data at run time.
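A minimal sketch of the contrast (all names and values here are illustrative; it uses the TF1-style API via `tf.compat.v1` so it also runs on TensorFlow 2 installs):

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()  # TF1 graph mode, as in the question's code

v = tf.Variable(1.0, name="v")            # needs an initial value; state persists across runs
p = tf.placeholder(tf.float32, name="p")  # no initial value; must be fed at run time
out = tf.add(v, p)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())  # variables must be initialized
    r1 = sess.run(out, feed_dict={p: 2.0})       # 1.0 + 2.0 = 3.0
    sess.run(v.assign(10.0))                     # variables can be mutated later
    r2 = sess.run(out, feed_dict={p: 2.0})       # 10.0 + 2.0 = 12.0
print(r1, r2)
```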

What does placeholder do in TensorFlow?

A placeholder is simply a variable that we will assign data to at a later date. It allows us to create our operations and build our computation graph, without needing the data. In TensorFlow terminology, we then feed data into the graph through these placeholders.
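A minimal sketch of that workflow, with illustrative names, again using the TF1-style `tf.compat.v1` API: the graph is built first with no data, and the data arrives only through `feed_dict` at run time.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Build the graph without any data
x = tf.placeholder(tf.float32, shape=(), name="x")
y = tf.placeholder(tf.float32, shape=(), name="y")
z = tf.multiply(x, y, name="z")

# Feed data into the graph through the placeholders
with tf.Session() as sess:
    result = sess.run(z, feed_dict={x: 2.0, y: 3.0})
print(result)  # 6.0
```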

How do you save a graph in TensorFlow?

Saving a TensorFlow model: after training is done, we want to save all the variables and the network graph to a file for future use. So, in TensorFlow, to save the graph and the values of all the parameters, we create an instance of the tf.train.Saver() class.
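A small sketch of what Saver produces (paths and names are illustrative; TF1-style `tf.compat.v1` API): the .meta file holds the graph structure, and the checkpoint files hold the variable values.

```python
import os
import tempfile
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

v = tf.Variable(42.0, name="v")
saver = tf.train.Saver()

ckpt_dir = tempfile.mkdtemp()                 # illustrative save location
ckpt_path = os.path.join(ckpt_dir, "model.ckpt")

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    saver.save(sess, ckpt_path)               # writes model.ckpt.meta + checkpoint data

files = sorted(os.listdir(ckpt_dir))
print(files)  # includes model.ckpt.meta, model.ckpt.index, data files, checkpoint
```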


2 Answers

mrry's answer is great, and the second solution really helps. However, the op name of the placeholder changes across TensorFlow versions. Here is my way to find the correct placeholder op name in the GraphDef part of the .meta file:

saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
imported_graph = tf.get_default_graph()
graph_op = imported_graph.get_operations()
with open('output.txt', 'w') as f:
    for i in graph_op:
        f.write(str(i))

In output.txt, we can easily find the placeholders' correct op names and other attributes. Here is part of my output file:

name: "input/input_image"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 112
      }
      dim {
        size: 112
      }
      dim {
        size: 3
      }
    }
  }
}

Evidently, in my TensorFlow version (1.6), the correct placeholder op name is Placeholder. Now, returning to mrry's solution: use [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"] to get a list of all the placeholder ops.
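For illustration, here is that filter applied to a freshly built graph (the placeholder names below are made up; a restored metagraph would list its own). Each matching Operation also exposes its dtype and shape attributes, which is exactly the information needed to build a feed_dict:

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# Stand-in graph; in practice this would come from tf.train.import_meta_graph(...)
a = tf.placeholder(tf.float32, shape=(None, 3), name="a")
b = tf.placeholder(tf.float32, name="b")

placeholders = [op for op in tf.get_default_graph().get_operations()
                if op.type == "Placeholder"]

for op in placeholders:
    # name, dtype, and shape tell you what the feed_dict must supply
    print(op.name, op.get_attr("dtype"), op.get_attr("shape"))
```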

This makes it easy and convenient to run inference with only the checkpoint files, without reconstructing the model. For example:

input_x = ... # prepare the model input

saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
graph_x = tf.get_default_graph().get_tensor_by_name('input/input_image:0')
graph_y = tf.get_default_graph().get_tensor_by_name('layer19/softmax:0')
sess = tf.Session()
saver.restore(sess, 'some_path/model.ckpt')

output_y = sess.run(graph_y, feed_dict={graph_x: input_x})
answered Nov 02 '22 by ChenYang


The tensors v1:0 and v2:0 were created from tf.placeholder() ops, whereas only tf.Variable objects are added to the "variables" (or "trainable_variables") collections. There is no general collection to which tf.placeholder() ops are added, so your options are:

  1. Add the tf.placeholder() ops to a collection (using tf.add_to_collection()) when constructing the original graph. You might need to add more metadata in order to suggest how the placeholders should be used.

  2. Use [x for x in tf.get_default_graph().get_operations() if x.type == "PlaceholderV2"] to get a list of placeholder ops after you import the metagraph.
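A sketch of the first option (the collection key "inputs" and the placeholder name are my own choices): tensors stored in a collection are serialized into the MetaGraphDef, so the same get_collection() call works after tf.train.import_meta_graph() in another file.

```python
import tensorflow.compat.v1 as tf
tf.disable_eager_execution()

# At graph-construction time: record the placeholder under a custom collection key
x = tf.placeholder(tf.float32, name="x")
tf.add_to_collection("inputs", x)

# Later (e.g. after import_meta_graph), look the placeholders up by that key
inputs = tf.get_default_graph().get_collection("inputs")
print([t.name for t in inputs])  # ['x:0']
```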

answered Nov 02 '22 by mrry