 

What does graph argument in tf.Session() do?

I am having trouble understanding the graph argument in tf.Session(). I tried looking at the TensorFlow documentation, but couldn't understand much.

I am trying to find out the difference between tf.Session() and tf.Session(graph=some_graph_inserted_here).

Question Context

Code A (Not Working):

import numpy as np
import tensorflow as tf

def predict():
    with tf.name_scope("predict"):
        with tf.Session() as sess:
            # The ValueError is raised on this line
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            loaded_graph = tf.get_default_graph()
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

This code fails at saver = tf.train.import_meta_graph("saved_models/testing.meta") with the following error: ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

Code B (Working):

import numpy as np
import tensorflow as tf

def predict():
    with tf.name_scope("predict"):
        loaded_graph = tf.Graph()  # a fresh, empty graph
        with tf.Session(graph=loaded_graph) as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
            _x = loaded_graph.get_tensor_by_name('x:0')
            print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))

The code does not work if I replace loaded_graph = tf.Graph() with loaded_graph = tf.get_default_graph(). Why?

Full code, if it helps: (https://gist.github.com/duemaster/f8cf05c0923ebabae476b83e895619ab)

asked Jul 04 '17 10:07 by Bosen


1 Answer

A TensorFlow Graph (tf.Graph) is an object that contains your tf.Tensor and tf.Operation objects.

When you create tensors (e.g. with tf.Variable or tf.constant) or operations (e.g. with tf.matmul), they are added to the current default graph; the graph attribute of each of these objects tells you which graph it belongs to. If you haven't specified otherwise, this is the graph returned by tf.get_default_graph().

But you can also work with multiple graphs, making one of them the default with a context manager:

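import tensorflow as tf

g = tf.Graph()
with g.as_default():
    # [your code]: ops and tensors created here are added to g
    c = tf.constant(1.0)  # c.graph is g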

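If you create several graphs, pass the one you want to run as the graph argument of tf.Session to tell TensorFlow which one to use; if you leave it out, the session uses the default graph. Note also that entering a session with a with block makes the session's graph the default graph for the duration of that block, which is why tf.train.import_meta_graph in Code B imports into loaded_graph.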
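A small sketch of choosing a graph per session, again assuming TensorFlow 1.14+ or 2.x with the TF1-style API under tf.compat.v1. The graphs g1 and g2 and the op name "x" are purely illustrative:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style Session API; assumes TensorFlow 1.14+ or 2.x

# Two independent graphs; each holds an op named "x".
g1 = tf.Graph()
with g1.as_default():
    x1 = tf.constant(1.0, name="x")

g2 = tf.Graph()
with g2.as_default():
    x2 = tf.constant(2.0, name="x")  # no clash: op names are per graph

# The graph argument decides which graph the session executes.
with tf1.Session(graph=g1) as sess:
    r1 = sess.run(x1)

with tf1.Session(graph=g2) as sess:
    r2 = sess.run(x2)
    try:
        sess.run(x1)  # x1 belongs to g1, not to this session's graph
        cross_graph_ok = True
    except ValueError:
        cross_graph_ok = False

print(r1, r2, cross_graph_ok)
```

The same op name can exist in two graphs without conflict, which is the property Code B relies on when it imports the saved model into a brand-new graph.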

In Code A, you

  • work with the default graph,
  • try to import the meta graph into it (which fails because it already contains some of the nodes) and,
  • would restore the model into it,

while in Code B, you

  • create a fresh new graph,
  • import the meta graph into it (which succeeds because it's an empty graph) and
  • restore it.
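The Code B pattern can be tried end to end with a toy model. This is a sketch under stated assumptions: TensorFlow 1.14+ or 2.x via tf.compat.v1, and a hypothetical model y = w * x + b whose helper names (train_and_save, predict_with_fresh_graph) and tensor names ("x", "output") stand in for the real saved_models/testing checkpoint:

```python
import os
import tempfile

import numpy as np
import tensorflow as tf

tf1 = tf.compat.v1  # TF1-style API; assumes TensorFlow 1.14+ or 2.x


def train_and_save(save_prefix):
    """Hypothetical toy model y = w * x + b, built in its own graph and saved."""
    g = tf.Graph()
    with g.as_default():
        x = tf1.placeholder(tf.float32, shape=[None, 1], name="x")
        w = tf1.Variable([[2.0]], name="w")
        b = tf1.Variable([1.0], name="b")
        tf.add(tf.matmul(x, w), b, name="output")
        saver = tf1.train.Saver()
        with tf1.Session(graph=g) as sess:
            sess.run(tf1.global_variables_initializer())
            saver.save(sess, save_prefix)  # also writes save_prefix + ".meta"


def predict_with_fresh_graph(save_prefix, value):
    """Code B pattern: restore the saved model into a new, empty graph."""
    loaded_graph = tf.Graph()
    with tf1.Session(graph=loaded_graph) as sess:
        # Inside this block loaded_graph is the default graph, so
        # import_meta_graph adds the saved nodes to it.
        saver = tf1.train.import_meta_graph(save_prefix + ".meta")
        saver.restore(sess, save_prefix)
        output_ = loaded_graph.get_tensor_by_name("output:0")
        _x = loaded_graph.get_tensor_by_name("x:0")
        feed = np.array([[value]], dtype=np.float32)
        return sess.run(output_, feed_dict={_x: feed})


save_dir = tempfile.mkdtemp()
prefix = os.path.join(save_dir, "toy_model")
train_and_save(prefix)
result = predict_with_fresh_graph(prefix, 3.0)
print(result)  # 2 * 3 + 1
```

Because every call builds a new graph, calling the prediction function repeatedly never runs into "name is already used".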

Useful link:

tf.Graph API

Edit:

The following code makes Code A work: I reset the default graph to a fresh one and removed the predict name_scope.

import numpy as np
import tensorflow as tf

def predict():
    # Replace the already populated default graph with an empty one
    tf.reset_default_graph()
    with tf.Session() as sess:
        saver = tf.train.import_meta_graph("saved_models/testing.meta")
        saver.restore(sess, "saved_models/testing")
        loaded_graph = tf.get_default_graph()
        output_ = loaded_graph.get_tensor_by_name('loss/network/output_layer/BiasAdd:0')
        _x = loaded_graph.get_tensor_by_name('x:0')
        print(sess.run(output_, feed_dict={_x: np.array([12003]).reshape([-1, 1])}))
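What the reset does can be checked in isolation. A minimal sketch, assuming TensorFlow 1.14+ or 2.x with tf.compat.v1 and graph mode; the op name "hidden_layer1" is only a stand-in for nodes left over from earlier training code:

```python
import tensorflow as tf

tf1 = tf.compat.v1  # assumes TensorFlow 1.14+ or 2.x

# Assumption: graph mode (the TF 1.x default); required under TF 2.x.
tf1.disable_eager_execution()

# Stand-in for ops that earlier training code left in the default graph.
tf.constant(1.0, name="hidden_layer1")
old_graph = tf1.get_default_graph()

# reset_default_graph() swaps in a new, empty global default graph.
# Per its docs it must not be called while a Session is active
# or inside a nested graph context.
tf1.reset_default_graph()
new_graph = tf1.get_default_graph()

print(len(old_graph.get_operations()), len(new_graph.get_operations()))
```

After the reset, tf.get_default_graph() returns a different, empty graph, so import_meta_graph no longer collides with existing names.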
answered Sep 23 '22 07:09 by pfm