I have a TensorFlow model for which I have the .meta and checkpoint files. I am trying to print all the placeholders that the model requires, without looking at the code that constructed the model, so that I can construct an input feed_dict without knowing how the model was created. For reference, here is the model construction code (in another file):
def save():
    import tensorflow as tf
    v1 = tf.placeholder(tf.float32, name="v1")
    v2 = tf.placeholder(tf.float32, name="v2")
    v3 = tf.multiply(v1, v2)
    vx = tf.Variable(10.0, name="vx")
    v4 = tf.add(v3, vx, name="v4")
    saver = tf.train.Saver()
    sess = tf.Session()
    sess.run(tf.initialize_all_variables())
    sess.run(vx.assign(tf.add(vx, vx)))
    result = sess.run(v4, feed_dict={v1:12.0, v2:3.3})
    print(result)
    saver.save(sess, "./model_ex1")
Now, in another file, I have the following code to restore the model:
def restore():
    import tensorflow as tf
    saver = tf.train.import_meta_graph("./model_ex1.meta")
    print(tf.get_default_graph().get_all_collection_keys())
    for v in tf.get_default_graph().get_collection("variables"):
        print(v)
    for v in tf.get_default_graph().get_collection("trainable_variables"):
        print(v)
    sess = tf.Session()
    saver.restore(sess, "./model_ex1")
    result = sess.run("v4:0", feed_dict={"v1:0": 12.0, "v2:0": 4.0})
    print(result)
However, when I print all the variables as above, I do not see "v1:0" and "v2:0" among the variable names. How can I identify the tensor names of the placeholders without looking at the code that created the model?
Think of a Variable in TensorFlow as a normal variable in a programming language: we initialize it, and we can modify it later as well. A placeholder, by contrast, doesn't require an initial value. It only reserves a slot in the graph whose value is supplied later.
A placeholder is simply a variable that we will assign data to at a later date. It allows us to create our operations and build our computation graph, without needing the data. In TensorFlow terminology, we then feed data into the graph through these placeholders.
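To make the difference concrete, here is a minimal sketch. The names x, w and y are made up for illustration, and the import through tensorflow.compat.v1 assumes TensorFlow 2.x or a late 1.x release (on older 1.x, a plain import tensorflow as tf should behave the same). The variable keeps its state between runs, while the placeholder has to be fed on every run.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 2.x or late 1.x

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    x = tf.placeholder(tf.float32, name="x")  # no value until it is fed
    w = tf.Variable(10.0, name="w")           # has an initial value and keeps state
    y = tf.add(x, w, name="y")
    double_w = w.assign(w * 2.0)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        first = sess.run(y, feed_dict={x: 1.0})   # 1.0 + 10.0
        sess.run(double_w)                        # w now holds 20.0
        second = sess.run(y, feed_dict={x: 1.0})  # 1.0 + 20.0
        try:
            sess.run(y)  # the placeholder is not fed here
            missing_feed_error = None
        except tf.errors.InvalidArgumentError as err:
            missing_feed_error = err

print(first, second)
print(type(missing_feed_error).__name__)
```

Running y without a feed fails because nothing in the graph provides a value for x, which is why you need to know the placeholder names before you can use a restored model.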
Saving a TensorFlow model: After training is done, we want to save all the variables and the network graph to files for future use. In TensorFlow, this is done by creating an instance of the tf.train.Saver() class and calling its save() method.
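As a small illustration of what the Saver writes, the sketch below saves a one-variable graph into a temporary directory and lists the resulting files. The directory and the model_ex1 prefix are only placeholders for this example, and the tensorflow.compat.v1 import again assumes TensorFlow 2.x or a late 1.x release.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x or late 1.x

tf.disable_v2_behavior()

out_dir = tempfile.mkdtemp()  # throwaway directory for this example

graph = tf.Graph()
with graph.as_default():
    tf.Variable(10.0, name="vx")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() returns the checkpoint path prefix it wrote to
        prefix = saver.save(sess, os.path.join(out_dir, "model_ex1"))

written = sorted(os.listdir(out_dir))
print(written)
```

The .meta file holds the graph structure (including placeholder ops), while the .index and .data files hold the variable values. That is why placeholders can be found from the .meta file alone.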
mrry's answer is great, and the second solution really helps. However, the op name of the placeholder changes between TensorFlow versions. Here is my way to find the correct placeholder op name in the GraphDef part of the .meta file:
saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
imported_graph = tf.get_default_graph()
graph_op = imported_graph.get_operations()
with open('output.txt', 'w') as f:
    for i in graph_op:
        f.write(str(i))
In the output.txt file, we can easily find the placeholder's correct op name and other attrs. Here is part of my output file:
name: "input/input_image"
op: "Placeholder"
attr {
  key: "dtype"
  value {
    type: DT_FLOAT
  }
}
attr {
  key: "shape"
  value {
    shape {
      dim {
        size: -1
      }
      dim {
        size: 112
      }
      dim {
        size: 112
      }
      dim {
        size: 3
      }
    }
  }
}
Clearly, in my TensorFlow version (1.6), the correct placeholder op name is Placeholder. Now, returning to mrry's solution: use [x for x in tf.get_default_graph().get_operations() if x.type == "Placeholder"] to get a list of all the placeholder ops.
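Here is a rough sketch that wraps that list comprehension in a helper and also reports each placeholder's tensor name, dtype and shape, which is what you need to build a feed_dict. The helper name list_placeholders and the small stand-in graph are made up for illustration. Accepting both "Placeholder" and "PlaceholderV2" is a hedge against the version differences mentioned above, and the tensorflow.compat.v1 import assumes TensorFlow 2.x or a late 1.x release.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 2.x or late 1.x

tf.disable_v2_behavior()

# Op type strings seen for placeholders in this thread
PLACEHOLDER_TYPES = ("Placeholder", "PlaceholderV2")


def list_placeholders(graph):
    """Return (tensor_name, dtype_name, shape) for each placeholder op in graph."""
    found = []
    for op in graph.get_operations():
        if op.type in PLACEHOLDER_TYPES:
            tensor = op.outputs[0]  # a placeholder op has a single output tensor
            found.append((tensor.name, tensor.dtype.name, tensor.shape))
    return found


# Stand-in graph; with a real model you would call
# tf.train.import_meta_graph(...) and pass tf.get_default_graph() instead.
graph = tf.Graph()
with graph.as_default():
    tf.placeholder(tf.float32, shape=[None, 112, 112, 3], name="input_image")
    tf.placeholder(tf.int32, name="label")
    tf.Variable(10.0, name="vx")  # variables are not reported

placeholders = list_placeholders(graph)
for name, dtype, shape in placeholders:
    print(name, dtype, shape)
```

The tensor names returned here (for example input_image:0) can be used directly as feed_dict keys or passed to get_tensor_by_name().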
Thus it's easy and convenient to perform the inference operation with only the ckpt files without needing to reconstruct the model. For example:
input_x = ... # prepare the model input
saver = tf.train.import_meta_graph('some_path/model.ckpt.meta')
graph_x = tf.get_default_graph().get_tensor_by_name('input/input_image:0')
graph_y = tf.get_default_graph().get_tensor_by_name('layer19/softmax:0')
sess = tf.Session()
saver.restore(sess, 'some_path/model.ckpt')
output_y = sess.run(graph_y, feed_dict={graph_x: input_x})
The tensors v1:0 and v2:0 were created from tf.placeholder() ops, whereas only tf.Variable objects are added to the "variables" (or "trainable_variables") collections. There is no general collection to which tf.placeholder() ops are added, so your options are:
1. Add the tf.placeholder() ops to a collection (using tf.add_to_collection()) when constructing the original graph. You might need to add more metadata in order to suggest how the placeholders should be used.
2. Use [x for x in tf.get_default_graph().get_operations() if x.type == "PlaceholderV2"] to get a list of placeholder ops after you import the metagraph.
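The first option (adding the placeholders to a collection) could look roughly like this. It is only a sketch: the collection name "inputs" is a made-up choice, the model mirrors the one in the question, the checkpoint goes to a temporary directory, and the tensorflow.compat.v1 import assumes TensorFlow 2.x or a late 1.x release.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 2.x or late 1.x

tf.disable_v2_behavior()

ckpt_path = os.path.join(tempfile.mkdtemp(), "model_ex1")

# Saving side: register the placeholders in a custom collection.
build_graph = tf.Graph()
with build_graph.as_default():
    v1 = tf.placeholder(tf.float32, name="v1")
    v2 = tf.placeholder(tf.float32, name="v2")
    vx = tf.Variable(10.0, name="vx")
    v4 = tf.add(tf.multiply(v1, v2), vx, name="v4")
    tf.add_to_collection("inputs", v1)  # "inputs" is a hypothetical name
    tf.add_to_collection("inputs", v2)
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver.save(sess, ckpt_path)

# Restoring side: the collection is stored in the .meta file.
restore_graph = tf.Graph()
with restore_graph.as_default():
    restorer = tf.train.import_meta_graph(ckpt_path + ".meta")
    inputs = tf.get_collection("inputs")
    input_names = [t.name for t in inputs]
    with tf.Session() as sess:
        restorer.restore(sess, ckpt_path)
        # 12.0 * 4.0 + 10.0
        result = sess.run("v4:0", feed_dict={inputs[0]: 12.0, inputs[1]: 4.0})

print(input_names)
print(result)
```

This only helps if you control the code that saves the model. For a checkpoint that already exists, filtering the ops by type is the option that remains.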