How to use Exponential Moving Average in Tensorflow

The problem

Tensorflow includes the function tf.train.ExponentialMovingAverage, which lets us apply a moving average to the model parameters; I've found this to be great for stabilizing the testing of the model.
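For reference, a minimal sketch of the basic API (TF 1.x graph mode, with a hypothetical single variable w just for illustration): ema.apply(...) creates the shadow variables and returns the op that updates them, and ema.average(var) returns the shadow variable holding the running average.

import tensorflow as tf

# Hypothetical parameter, only to illustrate the API
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())

ema = tf.train.ExponentialMovingAverage(decay=0.99)
# Creates the shadow variable and returns the op that updates it:
#   shadow = decay * shadow + (1 - decay) * w
ema_update_op = ema.apply([w])
# The shadow variable holding the moving average of w
w_ema = ema.average(w)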

That said, I've found it somewhat irritatingly hard to apply this to general models. My most successful approach so far (shown below) has been to write a function decorator and then put my whole NN inside a function.

This does, however, have several downsides: for one, it duplicates the whole graph, and second, I need to define my NN inside a function.

Is there a better way to do this?

Current Implementation

import functools

import tensorflow as tf


def ema_wrapper(is_training, decay=0.99):
    """Use Exponential Moving Average of parameters during testing.

    Parameters
    ----------
    is_training : bool or `tf.Tensor` of type bool
        EMA is applied if ``is_training`` is False.
    decay : float
        Decay rate for `tf.train.ExponentialMovingAverage`.
    """
    def function(fun):
        @functools.wraps(fun)
        def fun_wrapper(*args, **kwargs):
            # Regular call
            with tf.variable_scope('ema_wrapper', reuse=False) as scope:
                result_train = fun(*args, **kwargs)

            # Set up exponential moving average
            ema = tf.train.ExponentialMovingAverage(decay=decay)
            var_class = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                          scope.name)
            ema_op = ema.apply(var_class)

            # Add to collection so they are updated
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, ema_op)

            # Getter for the variables with EMA applied
            def ema_getter(getter, name, *args, **kwargs):
                var = getter(name, *args, **kwargs)
                ema_var = ema.average(var)
                return ema_var if ema_var is not None else var

            # Call with EMA applied
            with tf.variable_scope('ema_wrapper', reuse=True,
                                   custom_getter=ema_getter):
                result_test = fun(*args, **kwargs)

            # Return the correct version depending on if we're training or not
            return tf.cond(is_training,
                           lambda: result_train, lambda: result_test)
        return fun_wrapper
    return function

Example usage:

@ema_wrapper(is_training)
def neural_network(x):
    # If is_training is False, we will use an EMA of a instead
    a = tf.get_variable('a', [], tf.float32)
    return a * x
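
One thing to note about the decorator: it only registers the EMA update in tf.GraphKeys.UPDATE_OPS, so the training op has to depend on that collection or the averages never get updated. A minimal sketch of the wiring, continuing the example above (the stand-in loss is just for illustration):

x = tf.placeholder(tf.float32, shape=[None])
y = neural_network(x)                      # decorated function from above
loss = tf.reduce_mean(tf.square(y - 1.0))  # stand-in loss for illustration

# Make the optimizer step also run the EMA update registered by the decorator
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)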

Asked Mar 07 '18 by Jonas Adler

1 Answer

You can have an op that transfers the values from the EMA shadow variables to the original ones:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)
# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)

# Transfer EMA values to original variables
retrieve_ema_weights_op = tf.group(
    *(tf.assign(var, ema.average(var)) for var in model_vars))

with tf.Session() as sess:
    # Do training
    while ...:
        sess.run(train_op, ...)
    # Copy EMA values to weights
    sess.run(retrieve_ema_weights_op)
    # Test model with EMA weights
    # ...

EDIT:

I made a longer version with the ability to switch between training and test mode, using variable backups:

import tensorflow as tf

# Make model...
minimize_op = ...
model_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)

is_training = tf.get_variable('is_training', shape=(), dtype=tf.bool,
                              initializer=tf.constant_initializer(True, dtype=tf.bool))

# Make EMA object and update internal variables after optimization step
ema = tf.train.ExponentialMovingAverage(decay=decay)
with tf.control_dependencies([minimize_op]):
    train_op = ema.apply(model_vars)
# Make backup variables
with tf.variable_scope('BackupVariables'):
    backup_vars = [tf.get_variable(var.op.name, dtype=var.value().dtype, trainable=False,
                                   initializer=var.initialized_value())
                   for var in model_vars]

def ema_to_weights():
    return tf.group(*(tf.assign(var, ema.average(var).read_value())
                     for var in model_vars))
def save_weight_backups():
    return tf.group(*(tf.assign(bck, var.read_value())
                     for var, bck in zip(model_vars, backup_vars)))
def restore_weight_backups():
    return tf.group(*(tf.assign(var, bck.read_value())
                     for var, bck in zip(model_vars, backup_vars)))

def to_training():
    with tf.control_dependencies([tf.assign(is_training, True)]):
        return restore_weight_backups()

def to_testing():
    with tf.control_dependencies([tf.assign(is_training, False)]):
        with tf.control_dependencies([save_weight_backups()]):
            return ema_to_weights()

switch_to_train_mode_op = tf.cond(is_training, lambda: tf.group(), to_training)
switch_to_test_mode_op = tf.cond(is_training, to_testing, lambda: tf.group())

init_op = tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    # Unnecessary, since it begins in training mode, but harmless
    sess.run(switch_to_train_mode_op)
    # Do training
    while ...:
        sess.run(train_op, ...)
    # To test mode
    sess.run(switch_to_test_mode_op)
    # Switching multiple times should not overwrite backups
    sess.run(switch_to_test_mode_op)
    # Test model with EMA weights
    # ...
    # Back to training mode
    sess.run(switch_to_train_mode_op)
    # Keep training...
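
As a side note (not part of the approach above): if you evaluate from a checkpoint rather than inside the training session, tf.train.ExponentialMovingAverage also provides variables_to_restore(), which maps the shadow-variable names to the model variables so a Saver loads the EMA values directly into the weights. A minimal sketch, with a placeholder checkpoint path:

import tensorflow as tf

# Stand-in model variable; in practice this is the model built as above
w = tf.get_variable('w', shape=[], initializer=tf.zeros_initializer())

ema = tf.train.ExponentialMovingAverage(decay=0.99)
# Maps 'w/ExponentialMovingAverage' -> w, so restoring loads EMA values into w
ema_saver = tf.train.Saver(ema.variables_to_restore())

with tf.Session() as sess:
    # '/path/to/checkpoint' is a placeholder; the checkpoint must have been
    # written during training, after ema.apply(...) created the shadow variables
    ema_saver.restore(sess, '/path/to/checkpoint')
    # Evaluate the model here; w now holds its moving-averaged value
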
Answered Sep 28 '22 by jdehesa