It seems that tf.train.init_from_checkpoint initializes variables created via tf.get_variable, but not those created via tf.Variable.
For example, let's create two variables and save them:
import tensorflow as tf
tf.Variable(1.0, name='foo')
tf.get_variable('bar', initializer=1.0)
saver = tf.train.Saver()
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    saver.save(sess, './model', global_step=0)
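As a sanity check (assuming TF 1.x), tf.train.list_variables shows that both tensors are present in the saved checkpoint, regardless of how the variables were created:
import tensorflow as tf
# Both variables were written to the checkpoint.
print(tf.train.list_variables('./model-0'))  # e.g. [('bar', []), ('foo', [])]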
If I load them again via a tf.train.Saver, everything works fine: the variables are restored to 1 even though they are initialized to zero here:
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
saver = tf.train.Saver()
with tf.Session() as sess:
    saver.restore(sess, './model-0')
    print(f'foo: {foo.eval()} bar: {bar.eval()}')
    # foo: 1.0 bar: 1.0
However, if I use tf.train.init_from_checkpoint, I get:
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
tf.train.init_from_checkpoint('./model-0', {'/': '/'})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(f'foo: {foo.eval()} bar: {bar.eval()}')
    # foo: 0.0 bar: 1.0
bar is set back to 1 as expected, but foo remains at 0.
Is this the intended behavior? If so, why?
Yes, this is intended. The behaviour follows from the _init_from_checkpoint method, which iterates over the assignment map to decide which variables to restore:
for tensor_name_in_ckpt, current_var_or_name in sorted(
    six.iteritems(assignment_map)):
  var = None
It first sets the variable it is going to restore to None, and only reassigns it when one of several conditions is met. With the assignment map {'/': '/'}, the entry is a plain string, so the loop reaches the branch guarded by
if "/" in current_var_or_name
and looks variables up in a dictionary store_vars that was built earlier. That dictionary is filled right after _init_from_checkpoint checks whether the current entry of the assignment map is a tf.Variable, which is False at this point:
if _is_variable(current_var_or_name) or (
    isinstance(current_var_or_name, list)
    and all(_is_variable(v) for v in current_var_or_name)):
  var = current_var_or_name
else:
  store_vars = vs._get_default_variable_store()._vars
store_vars comes from the internal _VariableStore class, obtained via the _get_default_variable_store() function. That store only tracks variables created through get_variable: tf.get_variable registers its variable under the current variable scope (returned by tf.get_variable_scope()), whereas tf.Variable never registers itself in the variable store; it simply creates a new, unshared variable every time it is called. As a result, store_vars contains only the 'bar' variable, while foo is only added to the global variables collection by the tf.Variable op.
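You can see this directly by peeking at the same internal store the restore code consults. This is a minimal sketch for TF 1.x; _get_default_variable_store() and _vars are private internals, so treat it as illustrative only:
import tensorflow as tf
from tensorflow.python.ops import variable_scope as vs

foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)

# Only variables created via get_variable are registered in the default store.
print(sorted(vs._get_default_variable_store()._vars.keys()))  # ['bar']
# Both still appear in the GLOBAL_VARIABLES collection, though.
print([v.op.name for v in tf.global_variables()])             # ['foo', 'bar']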
However, if assignment_map contains {'foo': foo, 'bar': bar}, the loop in _init_from_checkpoint described above will find these variables and load them, and in that case your code will output foo: 1.0 bar: 1.0.
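Here is a sketch of that workaround, keeping the rest of the example unchanged (checkpoint tensor names as keys, the variable objects themselves as values):
import tensorflow as tf
foo = tf.Variable(0.0, name='foo')
bar = tf.get_variable('bar', initializer=0.0)
# Passing the variable objects makes _is_variable() succeed,
# so the variable-store lookup is bypassed for both entries.
tf.train.init_from_checkpoint('./model-0', {'foo': foo, 'bar': bar})
with tf.Session() as sess:
    tf.global_variables_initializer().run()
    print(f'foo: {foo.eval()} bar: {bar.eval()}')
    # foo: 1.0 bar: 1.0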
You can find the code in https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/training/checkpoint_utils.py and https://github.com/tensorflow/tensorflow/blob/r1.13/tensorflow/python/ops/variable_scope.py. Also see this answer: What is the default variable_scope in Tensorflow?