 

get_variable() does not work after session restoration

Tags:

tensorflow

I am trying to restore a session and call get_variable() to obtain an object of type tf.Variable (according to this answer), but it fails to find the variable. A minimal example that reproduces the problem follows.

First, create a variable and save the session.

import os

import tensorflow as tf  # TensorFlow 1.x API

var = tf.Variable(101)

with tf.Session() as sess:
    with tf.variable_scope(''):
        scoped_var = tf.get_variable('scoped_var', [])

    # Requesting the same name with reuse=True returns the existing variable.
    with tf.variable_scope('', reuse=True):
        new_scoped_var = tf.get_variable('scoped_var', [])

    assert scoped_var is new_scoped_var
    sess.run(tf.global_variables_initializer())
    saver = tf.train.Saver()
    print(sess.run(scoped_var))
    os.makedirs('data', exist_ok=True)  # Saver.save() needs the parent directory to exist
    saver.save(sess, 'data/sess')

Here, get_variable() inside a scope with reuse=True works fine. Next, restore the session from the saved files and try to get the variable again.

import tensorflow as tf

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('data/sess.meta')
    saver.restore(sess, 'data/sess')

    for v in tf.get_collection('variables'):
        print(v.name)

    print(tf.get_collection(("__variable_store",)))
    # Oops, it's empty!

    with tf.variable_scope('', reuse=True):
        # the next line fails
        new_scoped_var = tf.get_variable('scoped_var', [])

    print("new_scoped_var: ", new_scoped_var)

Output:

Variable:0
scoped_var:0
[]
Traceback (most recent call last):
...
ValueError: Variable scoped_var does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=None in VarScope?

As shown, get_variable() cannot find the variable, and the ("__variable_store",) collection, which get_variable() uses internally, is empty.
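To make the mechanism concrete, here is a toy model in plain Python. It is hypothetical and greatly simplified, not TensorFlow's actual code: `ToyGraph`, `ToyVariable`, `var_store`, `export_meta_graph()` and `import_meta_graph()` are illustrative names. It assumes, consistent with the output above, that the variables collection is written into the meta graph while the store behind get_variable() is a Python-only object that is not.

```python
# Toy model (hypothetical; NOT TensorFlow's real implementation) of why
# get_variable(..., reuse=True) fails after importing a meta graph.


class ToyVariable:
    def __init__(self, name, value=0.0):
        self.name = name
        self.value = value


class ToyGraph:
    def __init__(self):
        # Serializable collection, analogous to the "variables" collection.
        self.variables = []
        # Python-only registry, analogous to the ("__variable_store",) store.
        self.var_store = {}

    def get_variable(self, name, reuse=False):
        if reuse:
            if name not in self.var_store:
                raise ValueError(f"Variable {name} does not exist, or was not "
                                 "created with get_variable()")
            return self.var_store[name]
        var = ToyVariable(name)
        self.var_store[name] = var
        self.variables.append(var)
        return var

    def export_meta_graph(self):
        # Only the variable collection is written out; var_store is not.
        return {"variables": [v.name for v in self.variables]}

    @classmethod
    def import_meta_graph(cls, meta):
        graph = cls()
        # Variables reappear in the collection, but var_store stays empty.
        graph.variables = [ToyVariable(n) for n in meta["variables"]]
        return graph


g = ToyGraph()
v1 = g.get_variable("scoped_var")
assert g.get_variable("scoped_var", reuse=True) is v1  # reuse works before export

restored = ToyGraph.import_meta_graph(g.export_meta_graph())
print([v.name for v in restored.variables])  # ['scoped_var']
print(restored.var_store)                    # {}
try:
    restored.get_variable("scoped_var", reuse=True)
except ValueError as e:
    print("ValueError:", e)
```

In this sketch the variable is visible in the collection after import, exactly like `scoped_var:0` in the output above, yet the registry that reuse lookups consult is empty.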

Why does get_variable fail?

Alexander Lobov asked Nov 08 '22 00:11


1 Answer

Instead of importing the meta graph (which is useful if you want to modify the graph or control how it is loaded), you can try this:

import os

import tensorflow as tf  # TensorFlow 1.x API

with tf.Session() as sess:
  # Build the model first, so get_variable() registers the variable.
  with tf.variable_scope(''):
    scoped_var = tf.get_variable('scoped_var', [])

  with tf.variable_scope('', reuse=True):
    new_scoped_var = tf.get_variable('scoped_var', [])

  assert scoped_var is new_scoped_var
  saver = tf.train.Saver()
  # get_checkpoint_state() expects the checkpoint directory, not the file prefix.
  ckpt = tf.train.get_checkpoint_state('data')
  if ckpt is not None:
    saver.restore(sess, ckpt.model_checkpoint_path)
  else:
    sess.run(tf.global_variables_initializer())

  print(sess.run(scoped_var))
  os.makedirs('data', exist_ok=True)  # Saver.save() needs the parent directory
  saver.save(sess, 'data/sess')

  # Now continue to use the model as you normally would with a restored one.

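The main difference is that you build the model, calling get_variable() so that each variable is registered in the variable store, before calling saver.restore(). import_meta_graph() recreates the variables in the graph's collections, but it does not rebuild that store, which is why get_variable() with reuse=True cannot find them.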
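A plain-Python toy model can sketch why this ordering works. It is hypothetical and simplified, not TensorFlow's implementation: `ToyGraph`, `ToyVariable`, `save()` and `restore()` are illustrative names. The assumption it encodes is that a checkpoint holds only name/value pairs, so restoring just assigns saved values to variables that already exist, and get_variable() has already filled the store by the time restore runs.

```python
# Toy model (hypothetical; NOT TensorFlow's real implementation) of the
# "build the model first, then restore" pattern.


class ToyVariable:
    def __init__(self, name, value=0.0):
        self.name = name
        self.value = value


class ToyGraph:
    def __init__(self):
        self.var_store = {}  # Python-side registry used by get_variable()

    def get_variable(self, name, reuse=False):
        if reuse:
            if name not in self.var_store:
                raise ValueError(f"Variable {name} does not exist")
            return self.var_store[name]
        var = ToyVariable(name)
        self.var_store[name] = var
        return var


def save(graph):
    # A checkpoint holds only name -> value pairs, not Python objects.
    return {name: v.value for name, v in graph.var_store.items()}


def restore(graph, checkpoint):
    # Restoring assigns saved values to variables that already exist by name.
    for name, value in checkpoint.items():
        if name in graph.var_store:
            graph.var_store[name].value = value


# First run: build the model, give the variable a value, save it.
g = ToyGraph()
g.get_variable("scoped_var").value = 7.0
ckpt = save(g)

# Second run: rebuild the same model with get_variable(), then restore values.
g2 = ToyGraph()
scoped_var = g2.get_variable("scoped_var")
restore(g2, ckpt)
same = g2.get_variable("scoped_var", reuse=True)
print(same is scoped_var, same.value)  # True 7.0
```

Because the store is populated by rebuilding the model rather than by restoring, reuse lookups succeed and return the variable that now holds the saved value.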

Steven answered Nov 15 '22 07:11