I'm using Tensorflow v1.1 and I've been trying to figure out how to use my EMA'ed weights for inference, but no matter what I do I keep getting the error
Not found: Key W/ExponentialMovingAverage not found in checkpoint
even though when I loop through and print out all the tf.global_variables
the key exists
Here is a reproducible script heavily adapted from Facenet's unit test:
import os

import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# Create 100 phony x, y data points in NumPy, y = x * 0.1 + 0.3
x_data = np.random.rand(100).astype(np.float32)
y_data = x_data * 0.1 + 0.3
# Try to find values for W and b that compute y_data = W * x_data + b
# (We know that W should be 0.1 and b 0.3, but TensorFlow will
# figure that out for us.)
W = tf.Variable(tf.random_uniform([1], -1.0, 1.0), name='W')
b = tf.Variable(tf.zeros([1]), name='b')
y = W * x_data + b
# Minimize the mean squared errors.
loss = tf.reduce_mean(tf.square(y - y_data))
optimizer = tf.train.GradientDescentOptimizer(0.5)
opt_op = optimizer.minimize(loss)
# Track the moving averages of all trainable variables.
ema = tf.train.ExponentialMovingAverage(decay=0.9999)
variables = tf.trainable_variables()
print(variables)
averages_op = ema.apply(tf.trainable_variables())
with tf.control_dependencies([opt_op]):
    train_op = tf.group(averages_op)
# Before starting, initialize the variables. We will 'run' this first.
init = tf.global_variables_initializer()
saver = tf.train.Saver(tf.trainable_variables())
# Launch the graph.
sess = tf.Session()
sess.run(init)
# Fit the line.
for _ in range(201):
    sess.run(train_op)
w_reference = sess.run('W/ExponentialMovingAverage:0')
b_reference = sess.run('b/ExponentialMovingAverage:0')
saver.save(sess, os.path.join("model_ex1"))
tf.reset_default_graph()
tf.train.import_meta_graph("model_ex1.meta")
sess = tf.Session()
print('------------------------------------------------------')
for var in tf.global_variables():
    print('all variables: ' + var.op.name)
for var in tf.trainable_variables():
    print('normal variable: ' + var.op.name)
for var in tf.moving_average_variables():
    print('ema variable: ' + var.op.name)
print('------------------------------------------------------')
mode = 1
restore_vars = {}
if mode == 0:
    ema = tf.train.ExponentialMovingAverage(1.0)
    for var in tf.trainable_variables():
        print('%s: %s' % (ema.average_name(var), var.op.name))
        restore_vars[ema.average_name(var)] = var
elif mode == 1:
    for var in tf.trainable_variables():
        ema_name = var.op.name + '/ExponentialMovingAverage'
        print('%s: %s' % (ema_name, var.op.name))
        restore_vars[ema_name] = var
saver = tf.train.Saver(restore_vars, name='ema_restore')
saver.restore(sess, os.path.join("model_ex1")) # error happens here!
w_restored = sess.run('W:0')
b_restored = sess.run('b:0')
print(w_reference)
print(w_restored)
print(b_reference)
print(b_restored)
The "Key ... not found in checkpoint" error means that the variable exists in your model in memory, but not in the serialized checkpoint file on disk.
You should use the inspect_checkpoint tool to see exactly which tensors are saved in your checkpoint, and to understand why some exponential moving averages are not being saved here.
It's also not clear from your repro example which line is supposed to trigger the error.
I'd like to add a way to make the best use of the trained variables that are in the checkpoint.
Keep in mind that every variable in the saver's var_list must be contained in the checkpoint you configured. You can check the variables in the saver by:
print(restore_vars)
and those variables in the checkpoint by:
vars_in_checkpoint = tf.train.list_variables(os.path.join("model_ex1"))
in your case.
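Concretely, the mismatch can be checked with a plain set difference. The names below are what the question's script produces: the checkpoint was saved with `tf.train.Saver(tf.trainable_variables())`, so it contains only `W` and `b`, not the shadow variables:

```python
# Keys the saver in mode 1 tries to restore (keys of restore_vars):
restore_names = {"W/ExponentialMovingAverage", "b/ExponentialMovingAverage"}

# Names actually present in the checkpoint, as tf.train.list_variables
# would report them (the saver was built from tf.trainable_variables()):
checkpoint_names = {"W", "b"}

# Every name in this set triggers "Key ... not found in checkpoint".
missing = sorted(restore_names - checkpoint_names)
print(missing)
```

An empty `missing` set means the restore will succeed; anything left in it is exactly the key named in the error message.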
If restore_vars is fully contained in vars_in_checkpoint, the error will not be raised; otherwise, initialize all variables first:
all_variables = tf.get_collection_ref(tf.GraphKeys.GLOBAL_VARIABLES)
sess.run(tf.variables_initializer(all_variables))
All variables will be initialized, whether they are in the checkpoint or not. Then you can filter out the variables in restore_vars that are not included in the checkpoint (here, suppose all variables with ExponentialMovingAverage in their names are missing from the checkpoint):
temp_saver = tf.train.Saver(
    var_list=[v for v in all_variables if "ExponentialMovingAverage" not in v.name])
ckpt_state = tf.train.get_checkpoint_state(os.path.join("model_ex1"))
print('Loading checkpoint %s' % ckpt_state.model_checkpoint_path)
temp_saver.restore(sess, ckpt_state.model_checkpoint_path)
This may save some time compared to training the model from scratch. (In my scenario the restored variables made no significant improvement over training from scratch at the beginning, since all the old optimizer variables were abandoned. But it can still accelerate the optimization process significantly, because it works like pretraining some of the variables.)
In any case, some variables are worth restoring, such as embeddings and certain layers.