I am trying to restore a session and call get_variable()
to get an object of type
tf.Variable (as suggested in this answer),
but it fails to find the variable. A minimal example that reproduces the issue
is as follows.
First, create a variable and save the session.
import tensorflow as tf

var = tf.Variable(101)

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])
    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')
Here get_variable()
inside a scope with reuse=True
works fine.
Then, restore the session from a file and try to get the variable.
import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')
    for v in tf.get_collection('variables'):
        print(v.name)
    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!
    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])
        print("new_scoped_var: ", new_scoped_var)
Output:
Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?
As we can see, get_variable()
cannot find the variable, and the
("__variable_store",)
collection, which get_variable()
uses internally, is empty.
Why does get_variable
fail?
Instead of dealing with the meta graph (which can be helpful if you want to modify the graph, change how it is loaded, etc.), you can try this.
import tensorflow as tf

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])
    assert scoped_var is new_scoped_var
    saver = tf.train.Saver()
    # get_checkpoint_state expects the checkpoint directory ('data'),
    # not the 'data/sess' prefix that was passed to saver.save
    path = tf.train.get_checkpoint_state('data')
    if path is not None:
        saver.restore(sess, path.model_checkpoint_path)
    else:
        sess.run(tf.global_variables_initializer())
    print(sess.run(scoped_var))
    saver.save(sess, 'data/sess')
    # now continue to use as you normally would with a restored model
The main difference is that you set up your model (i.e. rebuild the graph with get_variable) before calling saver.restore, so get_variable can find the variables in the current graph.
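If you do want to keep using import_meta_graph as in the question, a possible workaround (a minimal sketch, assuming the checkpoint was saved under 'data/sess' as above) is to look the restored variable up by name in the variables collection, which import_meta_graph does rebuild, instead of going through get_variable():

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    # import_meta_graph rebuilds the 'variables' collection (as the
    # question's output shows), so the restored tf.Variable objects can
    # be found by name there, even though the internal store used by
    # get_variable() stays empty.
    scoped_var = [v for v in tf.global_variables()
                  if v.name == 'scoped_var:0'][0]
    print(sess.run(scoped_var))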