
Batch Normalization in TensorFlow

I noticed there are batch normalization functions already in the TensorFlow API. One thing I don't understand, though, is how to change the procedure between training and testing.

Batch normalization acts differently at test time than during training. Specifically, at test time one uses a fixed mean and variance (estimated during training) instead of the statistics of the current batch.

Is there some good example code somewhere? I saw some, but the variable scopes made it confusing to follow.

user3358117 asked May 02 '16 09:05



1 Answer

You are right: tf.nn.batch_normalization provides only the basic functionality for implementing batch normalization. You have to add the extra logic yourself to keep track of the moving mean and variance during training, and to use those trained values during inference. You can look at this example for a very general implementation, but here is a quick version that doesn't use gamma (the scale parameter):

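  import tensorflow as tf  # TensorFlow 1.x graph-mode API
  from tensorflow.python.training import moving_averages

  def batch_norm(image, shape, is_training, decay):
    # image: 4-D NHWC tensor; shape: [channels]; decay: moving-average decay.
    # is_training is a Python bool fixed when the graph is built.
    beta = tf.Variable(tf.zeros(shape), name='beta')
    moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean',
                              trainable=False)
    moving_variance = tf.Variable(tf.ones(shape), name='moving_variance',
                                  trainable=False)
    control_inputs = []
    if is_training:
      # Normalize with the statistics of the current batch ...
      mean, variance = tf.nn.moments(image, [0, 1, 2])
      # ... and fold them into the moving averages for use at inference.
      update_moving_mean = moving_averages.assign_moving_average(
          moving_mean, mean, decay)
      update_moving_variance = moving_averages.assign_moving_average(
          moving_variance, variance, decay)
      control_inputs = [update_moving_mean, update_moving_variance]
    else:
      # Inference: use the averages accumulated during training.
      mean = moving_mean
      variance = moving_variance
    # Run the moving-average updates whenever the output is evaluated.
    with tf.control_dependencies(control_inputs):
      return tf.nn.batch_normalization(
          image, mean=mean, variance=variance, offset=beta,
          scale=None, variance_epsilon=0.001)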
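
If it helps to see the same logic without graph machinery, here is a rough NumPy sketch of the train/inference switch. It is only an illustration, not TensorFlow's implementation: the class name BatchNormSketch and the default decay of 0.9 are made up for this example, and the moving averages use the plain update new = decay * old + (1 - decay) * batch, with no bias correction.

```python
import numpy as np


class BatchNormSketch:
    """Illustrative batch norm with an offset (beta) but no scale (gamma)."""

    def __init__(self, num_channels, decay=0.9, epsilon=0.001):
        # decay=0.9 is an arbitrary value chosen for this sketch.
        self.beta = np.zeros(num_channels)
        self.moving_mean = np.zeros(num_channels)
        self.moving_variance = np.ones(num_channels)
        self.decay = decay
        self.epsilon = epsilon

    def __call__(self, x, is_training):
        if is_training:
            # Per-channel statistics over batch, height and width (NHWC).
            mean = x.mean(axis=(0, 1, 2))
            variance = x.var(axis=(0, 1, 2))
            # Fold the batch statistics into the moving averages.
            self.moving_mean = (self.decay * self.moving_mean
                                + (1.0 - self.decay) * mean)
            self.moving_variance = (self.decay * self.moving_variance
                                    + (1.0 - self.decay) * variance)
        else:
            # Inference: fixed statistics, nothing is updated.
            mean = self.moving_mean
            variance = self.moving_variance
        return (x - mean) / np.sqrt(variance + self.epsilon) + self.beta


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    bn = BatchNormSketch(num_channels=3)
    batch = rng.normal(loc=5.0, scale=2.0, size=(8, 4, 4, 3))
    for _ in range(100):
        bn(batch, is_training=True)       # moving averages drift toward batch stats
    out = bn(batch, is_training=False)    # normalized with the stored averages
    print(out.mean(axis=(0, 1, 2)))
```

During training the output depends on the current batch, while at inference the same input always gives the same output because only the stored averages are used.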
keveman answered Sep 20 '22 12:09
