
How to use tf.reset_default_graph()

Whenever I try to use tf.reset_default_graph(), I get this error: IndexError: list index out of range. At which part of my code should I call it? When should I be using it?


I updated the code, but the error still occurs.

import numpy as np
import tensorflow as tf

def evaluate():
    with tf.name_scope("loss"):
        global x # x is a tf.placeholder()
        xentropy = tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=neural_network(x))
        loss = tf.reduce_mean(xentropy, name="loss")

    with tf.name_scope("train"):
        optimizer = tf.train.AdamOptimizer()
        training_op = optimizer.minimize(loss)

    with tf.name_scope("exec"):
        with tf.Session() as sess:
            for i in range(1, 2):
                sess.run(training_op, feed_dict={x: np.array(train_data).reshape([-1, 1]), y: label})
                print("Training " + str(i))
                saver = tf.train.Saver()
                saver.save(sess, "saved_models/testing")
                print("Model Saved.")

def predict():
    with tf.name_scope("predict"):
        tf.reset_default_graph()
        with tf.Session() as sess:
            saver = tf.train.import_meta_graph("saved_models/testing.meta")
            saver.restore(sess, "saved_models/testing")
            output_ = tf.get_default_graph().get_tensor_by_name('output_layer:0')
            print(sess.run(output_, feed_dict={x: np.array([12003]).reshape([-1, 1])}))

def main():
    print("Starting Program...")
    writer = tf.summary.FileWriter("mygraph/logs", tf.get_default_graph())

If I remove the tf.reset_default_graph() from the updated code, I get this error: ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used

From my current understanding, tf.reset_default_graph() removes all graphs, which is how I avoided the error mentioned above (ValueError: cannot add op with name hidden_layer1/kernel/Adam as that name is already used).
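A minimal sketch of this kind of name clash, assuming TensorFlow 2.x with its TF1 API (`tensorflow.compat.v1`) and graph mode switched on. It uses a closely related error rather than the exact one above: `tf.get_variable` refuses to create a second variable called `w` in the same graph. The helper `build()` and the name `w` are made up for illustration; they are not the question's model.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()  # build TF1-style graphs instead of running eagerly


def build():
    # Hypothetical helper: creates a variable called "w" in the default graph.
    return tf.get_variable("w", shape=[1], initializer=tf.zeros_initializer())


build()
try:
    build()  # same default graph, so the name "w" is already taken
    duplicate_error = None
except ValueError as err:
    duplicate_error = err

print(type(duplicate_error).__name__)  # ValueError

tf.reset_default_graph()  # swap in a new, empty global default graph
w = build()  # the name is free again in the new graph
print(w.name)  # w:0
```

Strictly speaking, the reset does not delete every graph: it replaces the global default graph with a fresh one, which is why the name becomes available again.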

Bosen asked Jul 04 '17 02:07


People also ask

What is the use of tf.reset_default_graph()?

tf.reset_default_graph() is defined in tensorflow/python/framework/ops.py. See the guide: Building Graphs > Utility functions. It clears the default graph stack and resets the global default graph.

What is tf.compat.v1.reset_default_graph()?

In TensorFlow 2.x the same function is available as tf.compat.v1.reset_default_graph(). It clears the default graph stack and resets the global default graph.
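A small check of what "resets the global default graph" means, assuming TensorFlow 2.x is installed. TF 2.x runs eagerly, so ops only land in a graph inside `as_default()`; the op name `a` is arbitrary.

```python
import tensorflow as tf  # TensorFlow 2.x; the TF1 graph API lives under tf.compat.v1

g1 = tf.compat.v1.get_default_graph()
with g1.as_default():
    # Graph mode inside as_default(), so this adds one Const op to g1.
    tf.constant(1, name="a")

# Called at top level, where no graph is nested.
tf.compat.v1.reset_default_graph()
g2 = tf.compat.v1.get_default_graph()

print(g1 is g2)                  # False: a brand-new global default graph
print(len(g1.get_operations()))  # 1: the old graph object still exists
print(len(g2.get_operations()))  # 0
```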


2 Answers

This is probably how you use it:

import tensorflow as tf
a = tf.constant(1)
with tf.Session() as sess:
    tf.reset_default_graph()

You get an error because you call it inside a session. From the tf.reset_default_graph() documentation:

Calling this function while a tf.Session or tf.InteractiveSession is active will result in undefined behavior. Using any previously created tf.Operation or tf.Tensor objects after calling this function will result in undefined behavior.
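A minimal sketch of where the call can and cannot go, assuming TensorFlow 2.x with `tensorflow.compat.v1`. A `tf.Session` used as a context manager pushes its graph as the default, so the default-graph stack is nested inside the `with` block. Recent TF releases guard against this and raise an `AssertionError` ("Do not use tf.reset_default_graph() to clear nested graphs..."). The 2017 release in the question may not have had that check, which would fit the less helpful IndexError.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

a = tf.constant(1)
with tf.Session() as sess:
    first = sess.run(a)
    try:
        # Wrong place: inside the session, its graph is pushed as the default.
        tf.reset_default_graph()
        inside_error = None
    except AssertionError as err:
        inside_error = err

# Right place: the session is closed and nothing is nested any more.
tf.reset_default_graph()
b = tf.constant(2)  # lives in the new default graph
with tf.Session() as sess:
    second = sess.run(b)

print(first, second)             # 1 2
print(inside_error is not None)  # True
```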

tf.reset_default_graph() can be helpful (at least for me) during the testing phase, while I experiment in a Jupyter notebook. However, I have never used it in production and do not see how it would be helpful there.
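Outside notebooks, a common alternative (not something this answer uses) is to give each model its own `tf.Graph` and pass it to the session explicitly, so the global default graph is never involved and nothing needs resetting. A sketch, assuming TensorFlow 2.x with `tensorflow.compat.v1`; `build_and_run` is a made-up helper.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_and_run(value):
    graph = tf.Graph()  # private graph for this model
    with graph.as_default():
        c = tf.constant(value, name="c")
        doubled = tf.multiply(c, 2, name="doubled")
    # The session runs the private graph, not the global default one.
    with tf.Session(graph=graph) as sess:
        return sess.run(doubled), doubled.name


r1, name1 = build_and_run(3)
r2, name2 = build_and_run(5)
print(r1, r2)        # 6 10
print(name1, name2)  # doubled:0 doubled:0 (no "_1" suffix: separate graphs)
print(len(tf.get_default_graph().get_operations()))  # 0: global graph untouched
```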

Here is an example that could be in a notebook:

import tensorflow as tf
# create some graph
with tf.Session() as sess:
    print(sess.run(...))

Now I no longer need this graph, but if I create another graph and visualize it in TensorBoard, I will see both the old nodes and the new ones. To fix this, I could restart the kernel and run only the next cell. However, I can just do:

tf.reset_default_graph()
# create a new graph
with tf.Session() as sess:
    print(sess.run(...))
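The effect described above can be checked without TensorBoard by listing the ops in the default graph. A sketch, assuming TensorFlow 2.x with `tensorflow.compat.v1`; `first_cell` and `second_cell` are hypothetical functions standing in for two notebook cells.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def first_cell():
    # The graph from an earlier notebook cell.
    x = tf.constant(3.0, name="x")
    return tf.square(x, name="x_squared")


def second_cell():
    # The new graph built in a later cell.
    y = tf.constant(4.0, name="y")
    return tf.sqrt(y, name="y_root")


def op_names():
    return {op.name for op in tf.get_default_graph().get_operations()}


first_cell()
second_cell()  # no reset: old and new nodes end up in one graph
without_reset = op_names()

tf.reset_default_graph()
first_cell()
tf.reset_default_graph()  # reset before building the new graph
second_cell()
with_reset = op_names()

print(sorted(without_reset))  # ['x', 'x_squared', 'y', 'y_root']
print(sorted(with_reset))     # ['y', 'y_root']
```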

Edit after the OP added their code:

with tf.name_scope("predict"):
    tf.reset_default_graph()

Here is roughly what happens: your code fails because tf.name_scope has already added something to the graph. While you are still inside this "adding something to the graph" block, you tell TF to remove the graph completely, but it can't, because it is busy adding something.
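A sketch of one way to restructure the question's `evaluate()`/`predict()` pair, assuming TensorFlow 2.x with `tensorflow.compat.v1`: reset the graph first, outside any `name_scope` or session, and only then import the saved meta graph. The one-weight model, the temporary checkpoint path, and the tensor names `x` and `output_layer` are hypothetical stand-ins for the question's network. The sketch also looks up the input placeholder in the restored graph by name, because a Python `x` created before the reset belongs to the old graph, so feeding it would likely fail.

```python
import os
import tempfile

import numpy as np
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()

ckpt = os.path.join(tempfile.mkdtemp(), "testing")

# "evaluate()": build a tiny model in the default graph and save it.
x = tf.placeholder(tf.float32, [None, 1], name="x")
w = tf.get_variable("w", [1, 1], initializer=tf.constant_initializer(2.0))
output = tf.matmul(x, w, name="output_layer")
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    tf.train.Saver().save(sess, ckpt)  # also writes ckpt + ".meta"

# "predict()": reset OUTSIDE any session or name scope, then import.
tf.reset_default_graph()
with tf.Session() as sess:
    saver = tf.train.import_meta_graph(ckpt + ".meta")
    saver.restore(sess, ckpt)
    graph = tf.get_default_graph()
    # Fetch both tensors from the restored graph, not from Python variables.
    x_in = graph.get_tensor_by_name("x:0")
    out = graph.get_tensor_by_name("output_layer:0")
    prediction = sess.run(out, feed_dict={x_in: np.array([[3.0]])})

print(prediction)  # [[6.]]
```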

Salvador Dali answered Oct 11 '22 02:10


For some reason, I need to build a new graph many times. I have just tested this approach, and it finally works! Many thanks for Salvador Dali's answer :-)

import tensorflow as tf
from my_models import Classifier

for i in range(10):
    tf.reset_default_graph()
    # build the graph
    global_step = tf.get_variable('global_step', [], initializer=tf.constant_initializer(0), trainable=False)
    classifier = Classifier(global_step)
    with tf.Session() as sess:
        print("do sth here.")
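A self-contained version of the loop above, assuming TensorFlow 2.x with `tensorflow.compat.v1`. `build_model` is a hypothetical one-weight stand-in for the answer's `Classifier` (from the user's own `my_models` module), and `global_step` is given an integer dtype here, which the original code does not do.

```python
import tensorflow.compat.v1 as tf

tf.disable_eager_execution()


def build_model(global_step):
    # Hypothetical stand-in for Classifier: a single weight whose train op
    # also increments global_step.
    w = tf.get_variable("w", [], initializer=tf.constant_initializer(1.0))
    loss = tf.square(w)
    optimizer = tf.train.GradientDescentOptimizer(0.1)
    return optimizer.minimize(loss, global_step=global_step)


steps = []
for i in range(3):
    # Fresh graph each iteration; without this, the second call to
    # tf.get_variable('global_step', ...) raises a ValueError.
    tf.reset_default_graph()
    global_step = tf.get_variable('global_step', [], dtype=tf.int64,
                                  initializer=tf.zeros_initializer(),
                                  trainable=False)
    train_op = build_model(global_step)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)
        steps.append(int(sess.run(global_step)))

print(steps)  # [1, 1, 1]: every iteration starts from a fresh graph
```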

庞琳卓 answered Oct 11 '22 02:10
