
TensorFlow: saving and restoring session - multiple variables

Given the following code:

import tensorflow as tf

with tf.Session() as sess:
    var = tf.Variable(42, name='var')
    sess.run(tf.global_variables_initializer())
    tf.train.export_meta_graph('file.meta')

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('file.meta')
    print(sess.run(var))

I get an error at the line saver = tf.train.import_meta_graph('file.meta') saying ValueError: At least two variables have the same name: var.

How can I get around this? Is there any way to overwrite the compute graph when importing a metagraph?

EDIT:

I have arrived at the following code:

import tensorflow as tf

file_name = "./file"

with tf.Session() as sess:
    var = tf.Variable(42, name='my_var')
    sess.run(tf.global_variables_initializer())

    saver = tf.train.Saver()
    saver.save(sess,file_name)
    saver.export_meta_graph(file_name + '.meta')

with tf.Session() as sess:
    saver = tf.train.import_meta_graph(file_name + '.meta')
    saver.restore(sess, file_name)
    print(sess.run(var))

    # new code that fails:
    saver = tf.train.Saver()
    saver.save(sess,file_name)
    saver.export_meta_graph(file_name + '.meta')

This prints the correct value for var, but when I come to save the graph a second time, I get the same original error: ValueError: At least two variables have the same name: var

Tom asked Dec 19 '25 13:12


1 Answer

In this case you are loading the variable in the default graph where you already have the variable defined. Hence you will need to reset the TensorFlow graph before importing it.

Do this by calling tf.reset_default_graph() before you import. Have a look at the "Import within the default graph" section under Exporting and Importing a MetaGraph.

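Of course, after the reset the Python name var no longer refers to anything in the new graph, so you will have to obtain the variable again after importing, for example from tf.global_variables(). Try out this code: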

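import tensorflow as tf

with tf.Session() as sess:
    var = tf.Variable(42, name='var')
    sess.run(tf.global_variables_initializer())
    tf.train.export_meta_graph('file.meta')

# Clear the default graph so the import does not collide with the existing 'var'.
tf.reset_default_graph()

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('file.meta')
    # The old Python reference is stale; fetch the imported variable from the new graph.
    var = tf.global_variables()[0]
    # Only the graph was exported (no checkpoint), so initialize the variable again.
    sess.run(tf.global_variables_initializer())
    print(sess.run(var))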
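
For the save, restore, and save-again flow from the question's EDIT, one alternative to resetting the default graph is to give each phase its own tf.Graph, so the import always lands in an empty graph. The following is only a sketch under some assumptions: it uses the tensorflow.compat.v1 module (available in TF 1.14+ and 2.x) so that the TF 1.x calls above still run, and the temporary checkpoint path is made up for the example.

```python
import os
import tempfile

# Assumption: TF 1.14+ or 2.x, where the v1 API lives under compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

# Hypothetical checkpoint prefix, chosen only for this sketch.
ckpt_prefix = os.path.join(tempfile.mkdtemp(), "file")

# Phase 1: build the variable and save it in its own graph.
build_graph = tf.Graph()
with build_graph.as_default():
    var = tf.Variable(42, name="my_var")
    saver = tf.train.Saver()
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        # save() also writes ckpt_prefix + ".meta" by default.
        saver.save(sess, ckpt_prefix)

# Phase 2: import into a separate, empty graph so no names collide.
restore_graph = tf.Graph()
with restore_graph.as_default():
    saver = tf.train.import_meta_graph(ckpt_prefix + ".meta")
    with tf.Session() as sess:
        saver.restore(sess, ckpt_prefix)
        # Fetch the imported variable from this graph's collection.
        restored_var = tf.global_variables()[0]
        restored_value = sess.run(restored_var)
        # Saving again reuses the saver returned by the import.
        saver.save(sess, ckpt_prefix)
        variable_names = [v.name for v in tf.global_variables()]

print(restored_value, variable_names)
```

Because the second graph starts empty, it ends up holding exactly one variable named my_var, which is why the second save does not hit the duplicate-name error.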

The reason your intermediate code wasn't working was that tf.get_variable() was creating a new variable, which was then randomly initialized. Make sure you call tf.get_variable_scope().reuse_variables() first. Have a look at Understanding tf.get_variable().

Unfortunately, a variable created using tf.Variable() cannot be reused with tf.get_variable() directly. Have a look at this comment and this comment to see exactly why. Hence you will need to use tf.get_variable() to create the variable if you wish to reuse it in the future.
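
To make the difference concrete, here is a small sketch of both cases. It assumes the tensorflow.compat.v1 module (TF 1.14+ or 2.x), and the scope name "model" and variable names "var" and "plain" are invented for the illustration.

```python
# Assumption: TF 1.14+ or 2.x, where the v1 API lives under compat.v1.
import tensorflow.compat.v1 as tf

tf.disable_v2_behavior()

graph = tf.Graph()
with graph.as_default():
    with tf.variable_scope("model"):
        # Created through get_variable, so it is registered for reuse.
        created = tf.get_variable(
            "var", shape=[], dtype=tf.int32,
            initializer=tf.constant_initializer(42))
        # Created through tf.Variable, so get_variable does not know about it.
        plain = tf.Variable(0, name="plain")

    with tf.variable_scope("model", reuse=True):
        # Same scope and name with reuse=True returns the existing variable.
        reused = tf.get_variable("var", shape=[], dtype=tf.int32)
        try:
            tf.get_variable("plain", shape=[], dtype=tf.int32)
            plain_reusable = True
        except ValueError:
            # get_variable cannot find a variable made with tf.Variable().
            plain_reusable = False

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        reused_value = sess.run(reused)

print(created.name, reused.name, reused_value, plain_reusable)
```

The lookup for "plain" fails even though a variable named model/plain exists in the graph, which is the behaviour the paragraph above describes.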

martianwars answered Dec 21 '25 03:12


