TensorFlow includes the class tf.train.ExponentialMovingAverage,
which allows us to apply a moving average to the parameters; I've found this to be great for stabilizing the model at test time.
With that said, I've found it somewhat irritatingly hard to apply this to general models. My most successful approach so far (shown below) has been to write a function decorator and then put my whole NN inside a function.
However, this has several downsides: first, it duplicates the whole graph, and second, I need to define my NN inside a function.
Is there a better way to do this?
def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average of parameters during testing.

    The decorated function is built twice inside the ``'ema_wrapper'``
    variable scope: once normally, and once with ``reuse=True`` through a
    custom getter that swaps each variable for its moving average. The EMA
    update op is added to ``tf.GraphKeys.UPDATE_OPS`` so that it can be run
    together with the other update ops.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay:
        Decay rate for `tf.train.ExponentialMovingAverage`

    Returns
    -------
    callable
        A decorator that wraps a graph-building function so that its result
        is taken from the EMA copy whenever ``is_training`` is False.
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name)
            ema_op = ema.apply(var_class)

            # Add to collection so they are updated
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter for the variables with EMA applied
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                # `ema.average` returns None for variables that are not
                # tracked; compare against None explicitly instead of
                # relying on the truth value of a TensorFlow variable.
                return var if ema_var is None else ema_var

            # Call with EMA applied
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # A plain Python bool is documented as accepted but cannot be
            # used as a `tf.cond` predicate, so pick the branch statically.
            if isinstance(is_training, bool):
                return result_train if is_training else result_test

            # Return the correct version depending on if we're training or not
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function
Example usage:
@ema_wrapper(is_training)
def neural_network(x):
    """Multiply ``x`` by the scalar variable ``a``.

    Because of the ``ema_wrapper`` decorator, an EMA of ``a`` is used
    instead of ``a`` itself whenever ``is_training`` is False.
    """
    scale = tf.get_variable('a', shape=[], dtype=tf.float32)
    return scale * x
You can have an op that transfers the value from the EMA variables to the original ones:
import tensorflow as tf
# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Make EMA object and update internal variables after optimization step.
# NOTE: `decay` is expected to be defined as part of the model setup above.
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    # Running `train_op` runs the optimizer step first, then the EMA update.
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables. The assign ops are unpacked
# into `tf.group` as separate arguments rather than passed as a single list.
retrieve_ema_weights_op = tf.group(
    *(tf.assign(var, ema.average(var)) for var in model_vars))

with tf.Session() as sess:
    # Do training
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights (this overwrites the trained weights)
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...
EDIT:
I made a longer version with the ability to switch between train and test mode with variable backups:
import tensorflow as tf
# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

# Mode flag stored in the graph. It is bookkeeping state, not a model
# parameter, so it is kept out of the trainable variables collection.
is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool),
                              trainable=False)

# Make EMA object and update internal variables after optimization step.
# NOTE: `decay` is expected to be defined as part of the model setup above.
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Make backup variables
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]


def ema_to_weights():
    """Return an op that overwrites every model variable with its EMA value."""
    return tf.group(*(tf.assign(var, ema.average(var).read_value())
                      for var in model_vars))


def save_weight_backups():
    """Return an op that copies every model variable into its backup variable."""
    return tf.group(*(tf.assign(bck, var.read_value())
                      for var, bck in zip(model_vars, backup_vars)))


def restore_weight_backups():
    """Return an op that copies every backup variable back into the model."""
    return tf.group(*(tf.assign(var, bck.read_value())
                      for var, bck in zip(model_vars, backup_vars)))


def to_training():
    """Set the flag to training mode, then restore the backed-up weights."""
    with tf.control_dependencies([tf.assign(is_training, True)]):
        return restore_weight_backups()


def to_testing():
    """Set the flag to testing mode, back up the weights, then load the EMA."""
    with tf.control_dependencies([tf.assign(is_training, False)]):
        # The backup must finish before the weights are overwritten.
        with tf.control_dependencies([save_weight_backups()]):
            return ema_to_weights()


# Each switch is a no-op when the graph is already in the requested mode.
switch_to_train_mode_op = tf.cond(is_training, lambda: tf.group(), to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, lambda: tf.group())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Unnecessary, since it begins in training mode, but harmless
    sess.run(switch_to_train_mode_op)
    # Do training
    while ...:
        sess.run(train_op, ...)
    # To test mode
    sess.run(switch_to_test_mode_op)
    # Switching multiple times should not overwrite backups
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Back to training mode
    sess.run(switch_to_train_mode_op)
    # Keep training...
If you love us, you can donate to us via PayPal or buy us a coffee so we can maintain and grow! Thank you!
Donate Us With