
Building TensorFlow Graphs Inside of Functions

I'm learning TensorFlow and am trying to structure my code properly. I (more or less) know how to build graphs either bare or as class methods, but I'm trying to figure out how best to build them inside functions. I've tried this simple example:

import tensorflow as tf

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

This should simply print 4. However, when I run it, I get the error:

Cannot interpret feed_dict key as Tensor: Tensor Tensor("Placeholder:0", dtype=int8) is not an element of this graph.

I'm pretty sure this is because the placeholder created inside build_graph() is local to that function, but shouldn't the with tf.Session(graph=graph) block take care of that? Is there a better way to use a feed dict in a situation like this?

Asked Jun 07 '17 16:06 by Alex


2 Answers

There are several options.

Option 1: just pass the name of the tensor instead of the tensor itself.

with tf.Session(graph=graph) as sess:
    # "Placeholder:0" and "Add:0" are the default names TensorFlow gave the two tensors
    feed = {"Placeholder:0": 3}
    print(sess.run("Add:0", feed_dict=feed))

In this case, it's probably best to give the nodes meaningful names, instead of using the default names as above:

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {"a:0": 3}
    print(sess.run("b:0", feed_dict=feed))

Recall that the outputs of an operation named "foo" are tensors named "foo:0", "foo:1", and so on. Most operations have just one output.
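To make the naming convention concrete, here is a small pure-Python sketch that splits a tensor name into its operation name and output index, and builds one back. split_tensor_name() and make_tensor_name() are hypothetical helpers written for illustration, not TensorFlow APIs; they only parse strings and never look anything up in a graph.

```python
def split_tensor_name(name):
    """Split a tensor name like "foo:0" into ("foo", 0).

    Hypothetical helper for illustration only; it parses the string and
    does not consult any graph.
    """
    # rpartition splits on the last ":", so name scopes like "layer1/foo" stay intact.
    op_name, sep, index = name.rpartition(":")
    if not sep or not op_name or not index.isdigit():
        raise ValueError("expected 'op_name:index', got %r" % (name,))
    return op_name, int(index)


def make_tensor_name(op_name, index=0):
    """Inverse of split_tensor_name: ("foo", 1) -> "foo:1"."""
    return "%s:%d" % (op_name, index)


print(split_tensor_name("b:0"))           # ('b', 0)
print(split_tensor_name("layer1/Add:1"))  # ('layer1/Add', 1)
print(make_tensor_name("Placeholder"))    # Placeholder:0
```

Splitting on the last ":" is unambiguous because TensorFlow op names may not contain a colon.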

Option 2: make your build_graph() function return all the important nodes.

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    return g, a, b

graph, a, b = build_graph()
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))
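If build_graph() grows to return many nodes, unpacking a long positional tuple becomes error-prone. One possible refinement of option 2 (not part of the original answer) is to return a namedtuple so callers refer to nodes by field name. GraphHandles and its field names are hypothetical; the sketch assumes a TensorFlow release that provides the tf.compat.v1 API, with eager execution disabled so placeholders and sessions behave as in TF 1.x.

```python
import collections

import tensorflow.compat.v1 as tf

# Under TF 2.x this restores TF 1.x graph/session behavior; harmless under 1.x.
tf.disable_eager_execution()

# Hypothetical container for the nodes callers need; field names are illustrative.
GraphHandles = collections.namedtuple("GraphHandles", ["graph", "a", "b"])


def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    # Return references to the tensors along with the graph that owns them.
    return GraphHandles(graph=g, a=a, b=b)


handles = build_graph()
with tf.Session(graph=handles.graph) as sess:
    result = sess.run(handles.b, feed_dict={handles.a: 3})
    print(result)
```

Because a namedtuple is still a tuple, graph, a, b = build_graph() keeps working for callers that prefer positional unpacking.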

Option 3: add important nodes to a collection.

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8)
        b = tf.add(a, tf.constant(1, dtype=tf.int8))
    # add_to_collection() is a Graph method, so it works outside the as_default() block
    for node in (a, b):
        g.add_to_collection("important_stuff", node)
    return g

graph = build_graph()
# get_collection() returns the nodes in the order they were added
a, b = graph.get_collection("important_stuff")
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

Option 4: as suggested by @pohe, you can use get_tensor_by_name().

def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    return g

graph = build_graph()
# get_tensor_by_name() expects a tensor name ("op_name:index"), not an op name
a, b = [graph.get_tensor_by_name(name) for name in ("a:0", "b:0")]
with tf.Session(graph=graph) as sess:
    feed = {a: 3}
    print(sess.run(b, feed_dict=feed))

I personally use option 2 most often: it's pretty straightforward and doesn't require playing with names. I use option 3 when the graph is large and will live for a long time, because the collection gets saved along with the model, and it's a quick way to document what really matters. I don't really use option 1, because I prefer to have actual references to objects (not sure why). Option 4 is useful when you are working with a graph built by someone else and they didn't give you direct references to its tensors.
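To illustrate why collections help with long-lived graphs, here is a sketch (not from the original answer) that exports the option 3 graph to an in-memory MetaGraphDef with tf.train.export_meta_graph(), re-imports it into a fresh graph with tf.train.import_meta_graph(), and recovers a and b from the collection by name. It assumes a TensorFlow release that provides tf.compat.v1; a real program would more typically write the .meta file through tf.train.Saver.

```python
import tensorflow.compat.v1 as tf

# Under TF 2.x this restores TF 1.x graph/session behavior; harmless under 1.x.
tf.disable_eager_execution()


def build_graph():
    g = tf.Graph()
    with g.as_default():
        a = tf.placeholder(tf.int8, name="a")
        b = tf.add(a, tf.constant(1, dtype=tf.int8), name="b")
    for node in (a, b):
        g.add_to_collection("important_stuff", node)
    return g


# Serialize the graph, including its collections, to a MetaGraphDef proto in memory.
meta_graph_def = tf.train.export_meta_graph(graph=build_graph())

# Rebuild it in a fresh graph, as a later process restoring the model would.
restored = tf.Graph()
with restored.as_default():
    tf.train.import_meta_graph(meta_graph_def)

# The collection survived the round trip, so no tensor names need to be hard-coded.
a, b = restored.get_collection("important_stuff")
with tf.Session(graph=restored) as sess:
    result = sess.run(b, feed_dict={a: 3})
    print(result)
```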

Hope this helps!

Answered Sep 19 '22 05:09 by MiniQuark


I'm looking for a better way as well, so my answer is probably not the best. Nevertheless, if you give a and b names when you create them inside build_graph(), for example:

a = tf.placeholder(tf.int8, name='a')
b = tf.add(a, tf.constant(1, dtype=tf.int8), name='b')

then you can retrieve them from the graph by name:

graph = build_graph()

a = graph.get_tensor_by_name('a:0')
b = graph.get_tensor_by_name('b:0')

with tf.Session(graph=graph) as sess:
    feed = {a: 3}      
    print(sess.run(b, feed_dict=feed))

P.S. Naming a and b is not strictly necessary (unnamed nodes get default names such as "Placeholder:0"); it just makes them easier to reference later. Also, if you've found a better solution, please share it.

Answered Sep 19 '22 05:09 by pohe