I'm learning TensorFlow and am trying to structure my code properly. I (more or less) know how to build graphs, either bare or as class methods, but I'm not sure how best to organize them. I've tried this simple example:
```python
import tensorflow as tf  # TensorFlow 1.x API (tf.compat.v1 in TF 2.x)

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
```
which should just print out 4. However, when I do that, I get the error:
```
Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", dtype=int8) is not an element of this graph.
```
I'm pretty sure this is because the placeholder inside the function `build_graph` is private, but shouldn't the `with tf.Session(graph=graph)` take care of that? Is there a better way of using a feed dict in a situation like this?
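To see why passing `graph=graph` to the session does not help, here is a toy pure-Python model of the one rule at play: a tensor belongs to whichever graph was the default when it was created, and a session bound to one graph rejects tensors from any other. The classes `ToyGraph`, `ToyTensor`, `ToySession` and the helpers `placeholder`/`add_const` are hypothetical stand-ins, not TensorFlow's API or implementation. The sketch also assumes, as the error message suggests, that the `a` and `b` used in the session are left over from earlier code that ran on the global default graph.

```python
import contextlib


class ToyGraph:
    """Hypothetical stand-in for tf.Graph: only tracks which graph is the default."""

    _default_stack = []  # the innermost `with g.as_default()` is last

    @contextlib.contextmanager
    def as_default(self):
        ToyGraph._default_stack.append(self)
        try:
            yield self
        finally:
            ToyGraph._default_stack.pop()

    @staticmethod
    def get_default():
        return ToyGraph._default_stack[-1]


# A global default graph exists before any `with` block, as in TF 1.x.
ToyGraph._default_stack.append(ToyGraph())


class ToyTensor:
    """Hypothetical tensor: remembers its owning graph and how to compute itself."""

    def __init__(self, name, compute):
        self.graph = ToyGraph.get_default()  # ownership is fixed at creation time
        self.name = name
        self.compute = compute  # function: feed_dict -> value


def placeholder(name="Placeholder"):
    tensor = ToyTensor(name, None)
    tensor.compute = lambda feed: feed[tensor]  # the value must come from the feed
    return tensor


def add_const(x, c, name="Add"):
    return ToyTensor(name, lambda feed: x.compute(feed) + c)


class ToySession:
    def __init__(self, graph):
        self.graph = graph

    def run(self, fetch, feed_dict):
        # Reject any feed key or fetch that was created on a different graph.
        for tensor in [*feed_dict, fetch]:
            if tensor.graph is not self.graph:
                raise ValueError(f"{tensor.name} is not an element of this graph.")
        return fetch.compute(feed_dict)


# Leftover tensors on the global default graph (e.g. from an earlier run).
a = placeholder()
b = add_const(a, 1)


def build_graph():
    g = ToyGraph()
    with g.as_default():
        a = placeholder()  # local name: shadows the global `a` only in here
        b = add_const(a, 1)
    return g, a, b


graph, local_a, local_b = build_graph()
print(ToySession(graph).run(local_b, {local_a: 3}))  # 4
try:
    ToySession(graph).run(b, {a: 3})  # the global a/b belong to the default graph
except ValueError as err:
    print(err)  # Placeholder is not an element of this graph.
```

The last lines show both sides: the leftover global tensors are rejected by a session bound to the new graph, while the handles returned from inside `build_graph()` work, which is the idea behind option 2 in the answer below.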
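The `a` and `b` you create inside `build_graph()` are local variables, so the `a` and `b` you use in the session must be other tensors (most likely left over from earlier code that ran on the default graph). That is why TensorFlow says they are not elements of `graph`. Passing `graph=graph` to `tf.Session` only chooses which graph to run; it doesn't make the function's local variables visible outside it. There are several options for getting hold of the right tensors.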
Option 1: just pass the name of the tensor instead of the tensor itself.
```python
with tf.Session(graph=graph) as sess:
    feed = {"Placeholder:0": 3}
    print(sess.run("Add:0", feed_dict=feed))
```
In this case, it's probably best to give the nodes meaningful names, instead of using the default names as above:
```python
def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {"a:0": 3}
    print(sess.run("b:0", feed_dict=feed))
```
Recall that the outputs of an operation named `"foo"` are tensors named `"foo:0"`, `"foo:1"`, and so on. Most operations have just one output.
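The naming rule is plain string structure, so it can be sketched without TensorFlow at all. The helpers `split_tensor_name` and `tensor_name` below are hypothetical, not TensorFlow functions; they split and build `"<op_name>:<output_index>"` strings. Name scopes show up as `/` inside the op name (e.g. a made-up `layer1/b:1`), which the sketch leaves intact.

```python
def split_tensor_name(name):
    """Split "op_name:output_index" into (op_name, output_index)."""
    op_name, sep, index = name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        # e.g. "b" is an operation name, not a tensor name
        raise ValueError(f"expected '<op_name>:<output_index>', got {name!r}")
    return op_name, int(index)


def tensor_name(op_name, output_index=0):
    """Build the name of output number `output_index` of operation `op_name`."""
    return f"{op_name}:{output_index}"


print(split_tensor_name("b:0"))         # ('b', 0)
print(split_tensor_name("layer1/b:1"))  # ('layer1/b', 1)
print(tensor_name("a"))                 # a:0
```

The distinction matters in practice: in TF 1.x, `sess.run("b")` fetches the operation `b` rather than its output tensor and returns `None`, whereas `sess.run("b:0")` returns the value.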
Option 2: make your `build_graph()` function return all the important nodes.
```python
def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
```
Option 3: add important nodes to a collection.

```python
def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
        for node in (a, b):
            g.add_to_collection("important_stuff", node)
    return g

graph = build_graph()
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
```
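One thing to keep in mind with option 3: `get_collection()` returns the values in the order they were added, so `a, b = graph.get_collection("important_stuff")` depends on the order of the `add_to_collection` calls. The plain-Python toy below (a hypothetical `ToyCollections` class that models only this ordering behaviour, not TensorFlow's implementation; strings stand in for tensors) contrasts that with using one collection per role. The collection names `"inputs"` and `"outputs"` are made up for illustration.

```python
from collections import defaultdict


class ToyCollections:
    """Hypothetical stand-in for the collection methods of tf.Graph."""

    def __init__(self):
        self._collections = defaultdict(list)

    def add_to_collection(self, name, value):
        self._collections[name].append(value)

    def get_collection(self, name):
        # A copy, in insertion order; an unknown name gives an empty list.
        return list(self._collections.get(name, []))


g = ToyCollections()

# One shared collection: unpacking relies on the order of the add calls.
for node in ("b", "a"):  # imagine the loop order changed during a refactor
    g.add_to_collection("important_stuff", node)
a, b = g.get_collection("important_stuff")
print(a, b)  # b a  (silently swapped)

# One collection per role: each lookup states what it expects.
g.add_to_collection("inputs", "a")
g.add_to_collection("outputs", "b")
(a,) = g.get_collection("inputs")
(b,) = g.get_collection("outputs")
print(a, b)  # a b
```

Unpacking into a one-element tuple also fails loudly with a `ValueError` if something extra ends up in a collection, instead of silently picking the wrong node.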
Option 4: as suggested by @pohe, you can use `get_tensor_by_name()`:
```python
def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
```
I personally use option 2 most often: it's straightforward and doesn't require playing with names. I use option 3 when the graph is large and will live for a long time, because the collection gets saved along with the model, and it's a quick way to document what really matters. I rarely use option 1, because I prefer to have actual references to objects (not sure why). Option 4 is useful when you are working with a graph built by someone else who didn't give you direct references to its tensors.
Hope this helps!
I'm looking for a better way as well, so my answer is probably not the best. Nevertheless, if you give `a` and `b` a name, such as

```python
a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')
```

then you can do:

```python
graph = build_graph()
a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
```
P.S. Naming `a` and `b` is not strictly necessary (without names you would look them up as `'Placeholder:0'` and `'Add:0'`); it just makes them easier to reference later. Also, if you've found a better solution, please share it too.