
tf.train.init_from_checkpoint does not initialize variables created with tf.Variable

It seems that tf.train.init_from_checkpoint initializes variables created via tf.get_variable but not those created via tf.Variable.

For example, let's create two variables and save them:

import tensorflow as tf

tf.Variable(1.0, name='foo')              # created with tf.Variable
tf.get_variable('bar', initializer=1.0)   # created with tf.get_variable
saver = tf.train.Saver()
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  saver.save(sess, './model', global_step=0)  # writes ./model-0

If I load them back via a tf.train.Saver, everything works fine: both variables are restored to 1 even though they are initialized to 0 here:

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
  saver.restore(sess, './model-0')
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 1.0  bar: 1.0

However, if I use tf.train.init_from_checkpoint, I get:

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/':'/'})
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 0.0  bar: 1.0

bar is set back to 1 as expected, but foo remains at 0.

Is this the intended behavior? If so, why?

asked Feb 27 '19 by user209974
1 Answer

Yes, this is intended. The behaviour is implemented in the internal _init_from_checkpoint function, which iterates over the assignment map when deciding which variables to restore:

for tensor_name_in_ckpt, current_var_or_name in sorted(
    six.iteritems(assignment_map)):
  var = None

It first sets the variable it is going to restore to None and only reassigns it when one of several conditions is met. In this particular case, since the assignment map is {'/': '/'}, the branch that fires is guarded by

if "/" in current_var_or_name

So variables are looked up by name in a dictionary store_vars, which was built earlier, right after _init_from_checkpoint checked whether the current entry of the assignment map is a tf.Variable object; that check is False here because we passed the string '/':

if _is_variable(current_var_or_name) or (
    isinstance(current_var_or_name, list)
    and all(_is_variable(v) for v in current_var_or_name)):
  var = current_var_or_name
else:
  store_vars = vs._get_default_variable_store()._vars

store_vars is kept by the internal class _VariableStore and retrieved through the module-level function _get_default_variable_store() in variable_scope.py. This store only tracks variables created through tf.get_variable: tf.get_variable first calls tf.get_variable_scope(), which returns the current variable scope, and registers the new variable in the store under that scope. tf.Variable, by contrast, bypasses the variable store entirely; it creates a new variable on every call and does not allow sharing, so 'foo' never enters the store.

As a result, store_vars contains only the 'bar' variable; foo is added only to the global variables collection by the tf.Variable op and is invisible to the lookup above.
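You can check this yourself. Here is a minimal sketch (TensorFlow 1.x, using the same internal _get_default_variable_store() that the restore code relies on) showing that only 'bar' is registered in the variable store, while both variables appear in the global variables collection:

import tensorflow as tf
from tensorflow.python.ops import variable_scope as vs

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)

# Only tf.get_variable registers its variable in the variable store.
print(sorted(vs._get_default_variable_store()._vars.keys()))
# ['bar']

# Both variables end up in the global variables collection.
print([v.op.name for v in tf.global_variables()])
# ['foo', 'bar']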

However, if assignment_map maps checkpoint names to the variable objects themselves, e.g. {'foo': foo, 'bar': bar}, the loop in _init_from_checkpoint takes the _is_variable branch, finds both variables, and loads them. In that case your code outputs foo: 1.0  bar: 1.0.
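For completeness, here is that workaround applied to the snippet from the question (a sketch assuming the checkpoint from the question still exists at ./model-0):

import tensorflow as tf

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
# Map checkpoint tensor names to the Variable objects themselves, so that
# _init_from_checkpoint takes the _is_variable branch instead of store_vars.
tf.train.init_from_checkpoint('./model-0', {'foo': foo, 'bar': bar})
with tf.Session() as sess:
  tf.global_variables_initializer().run()
  print(f'foo: {foo.eval()}  bar: {bar.eval()}')
  # foo: 1.0  bar: 1.0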

You can find the code in https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py and https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py. Also see this answer: What is the default variable_scope in Tensorflow?

answered by Sharky