I was trying to use batch normalization to train my neural networks in TensorFlow, but it was unclear to me how to use the official layer implementation of batch normalization (note that this is different from the one in the API).
After some painful digging through their GitHub issues, it seems that one needs a tf.cond to use it properly, and also a 'reuse=True' flag so that the BN shift and scale variables are properly reused. After figuring that out, I wrote up a short description of what I believe is the right way to use it here.
Now I have written a short script to test it (only a single layer and a ReLU; hard to make it smaller than this). However, I am not 100% sure how to test it. Right now my code runs with no error messages but unexpectedly returns NaNs, which lowers my confidence that the code I gave in the other post is right. Or maybe the network I have is weird. Either way, does someone know what's wrong? Here is the code:
import tensorflow as tf
# download and install the MNIST data automatically
from tensorflow.examples.tutorials.mnist import input_data
from tensorflow.contrib.layers.python.layers import batch_norm as batch_norm
def batch_norm_layer(x,train_phase,scope_bn):
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
    is_training=True,
    reuse=None, # is this right?
    trainable=True,
    scope=scope_bn)
    bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
    is_training=False,
    reuse=True, # is this right?
    trainable=True,
    scope=scope_bn)
    z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return z
def get_NN_layer(x, input_dim, output_dim, scope, train_phase):
    with tf.name_scope(scope+'vars'):
        W = tf.Variable(tf.truncated_normal(shape=[input_dim, output_dim], mean=0.0, stddev=0.1))
        b = tf.Variable(tf.constant(0.1, shape=[output_dim]))
    with tf.name_scope(scope+'Z'):
        z = tf.matmul(x,W) + b
    with tf.name_scope(scope+'BN'):
        if train_phase is not None:
            z = batch_norm_layer(z,train_phase,scope+'BN_unit')
    with tf.name_scope(scope+'A'):
        a = tf.nn.relu(z) # (M x D1) = (M x D) * (D x D1)
    return a
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
# placeholder for data
x = tf.placeholder(tf.float32, [None, 784])
# placeholder that switches BN to training mode (True) or inference mode (False)
train_phase = tf.placeholder(tf.bool, name='phase_train')
# variables for parameters
hiden_units = 25
layer1 = get_NN_layer(x, input_dim=784, output_dim=hiden_units, scope='layer1', train_phase=train_phase)
# create model
W_final = tf.Variable(tf.truncated_normal(shape=[hiden_units, 10], mean=0.0, stddev=0.1))
b_final = tf.Variable(tf.constant(0.1, shape=[10]))
y = tf.nn.softmax(tf.matmul(layer1, W_final) + b_final)
### training
y_ = tf.placeholder(tf.float32, [None, 10])
cross_entropy = tf.reduce_mean( -tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1]) )
train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
with tf.Session() as sess:
    sess.run(tf.initialize_all_variables())
    steps = 3000
    for iter_step in xrange(steps):
        #feed_dict_batch = get_batch_feed(X_train, Y_train, M, phase_train)
        batch_xs, batch_ys = mnist.train.next_batch(100)
        # Collect model statistics
        if iter_step%1000 == 0:
            batch_xstrain, batch_ystrain = batch_xs, batch_ys # simulates train data
            batch_xcv, batch_ycv = mnist.test.next_batch(5000) # simulates CV data
            batch_xtest, batch_ytest = mnist.test.next_batch(5000) # simulates test data
            # do inference
            train_error = sess.run(fetches=cross_entropy, feed_dict={x: batch_xs, y_:batch_ys, train_phase: False})
            cv_error = sess.run(fetches=cross_entropy, feed_dict={x: batch_xcv, y_:batch_ycv, train_phase: False})
            test_error = sess.run(fetches=cross_entropy, feed_dict={x: batch_xtest, y_:batch_ytest, train_phase: False})
            def do_stuff_with_errors(*args):
                print args
            do_stuff_with_errors(train_error, cv_error, test_error)
        # Run Train Step
        sess.run(fetches=train_step, feed_dict={x: batch_xs, y_:batch_ys, train_phase: True})
    # list of booleans indicating correct predictions
    correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
    # accuracy
    accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
    print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels, train_phase: False}))
When I run it I get:
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(2.3474066, 2.3498712, 2.3461707)
(0.49414295, 0.88536006, 0.91152304)
(0.51632041, 0.393666, nan)
0.9296
It used to be that all of the last values were NaN, and now only a few of them are. Is everything fine, or am I being paranoid?
Batch norm is a normalization technique applied between the layers of a neural network instead of to the raw data. It is computed over mini-batches instead of the full data set. It serves to speed up training and allow higher learning rates, making learning easier. Each neuron's output is normalized using its mini-batch mean and standard deviation.
A Normalization layer should always either be adapted over a dataset or passed mean and variance. During adapt(), the layer will compute a mean and variance separately for each position in each axis specified by the axis argument. To calculate a single mean and variance over the input data, simply pass axis=None.
Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1. Importantly, batch normalization works differently during training and during inference.
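A minimal sketch of that train/inference difference, written in plain Python rather than TensorFlow. The function names, the epsilon value, and the sample numbers are made up for illustration; gamma and beta stand in for the learned scale and shift.

```python
import math

def batch_stats(values):
    """Mean and (biased) variance of one feature over a mini-batch."""
    mean = sum(values) / len(values)
    var = sum((v - mean) ** 2 for v in values) / len(values)
    return mean, var

def batch_norm_1d(values, mean, var, gamma=1.0, beta=0.0, epsilon=1e-3):
    """Normalize one feature with the given statistics, then scale and shift."""
    return [gamma * (v - mean) / math.sqrt(var + epsilon) + beta for v in values]

batch = [2.0, 4.0, 6.0, 8.0]  # one feature across a mini-batch of 4

# Training: the statistics come from the current mini-batch.
train_mean, train_var = batch_stats(batch)
train_out = batch_norm_1d(batch, train_mean, train_var)

# Inference: the statistics are the moving averages accumulated during
# training (hypothetical values here), so each output depends only on its
# own input and not on the rest of the batch.
moving_mean, moving_var = 4.5, 5.5
infer_out = batch_norm_1d(batch, moving_mean, moving_var)

print(train_out)
print(infer_out)
```

In training mode the outputs are centered on the batch itself; in inference mode the same inputs give different outputs unless the moving statistics happen to match the batch statistics.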
When to use Batch Normalization? It can be used in convolutional neural networks, recurrent neural networks, and ordinary fully connected networks. In practice, it is usually inserted between a layer's linear (affine or convolution) output and its activation function, although placing it after the activation is also common.
I am not sure if this will solve your problem, but the documentation for BatchNorm is not very easy to use or informative, so here is a short recap on how to use simple BatchNorm:
First of all, you define your BatchNorm layer. If you want to use it after an affine/fully-connected layer, you do this (just an example, order can be different/as you desire):
...
inputs = tf.matmul(inputs, W) + b
inputs = tf.layers.batch_normalization(inputs, training=is_training)
inputs = tf.nn.relu(inputs)
...
The function tf.layers.batch_normalization creates internal variables for the moving mean and moving variance. The ops that update them are not run automatically; they are added to the tf.GraphKeys.UPDATE_OPS collection. As such, you must make your training op depend on them, as follows (after all layers have been defined!):
...
extra_update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(extra_update_ops):
    trainer = tf.train.AdamOptimizer() 
    updateModel = trainer.minimize(loss, global_step=global_step)
...
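To see why skipping those update ops matters, here is a rough plain-Python sketch of the moving-average update they perform. I am assuming the usual form moving = decay * moving + (1 - decay) * batch_value, with the moving mean starting at 0; the helper name and the constant batch mean are made up for illustration.

```python
def moving_average_after(steps, batch_value, decay, start=0.0):
    """Apply moving = decay * moving + (1 - decay) * batch_value `steps` times."""
    moving = start
    for _ in range(steps):
        moving = decay * moving + (1.0 - decay) * batch_value
    return moving

true_mean = 10.0  # pretend every mini-batch has exactly this mean

# Update ops never run: the statistic stays at its initial value.
never_updated = moving_average_after(0, true_mean, decay=0.999)

# Update ops run on every training step, with the decay used in the question.
after_1000 = moving_average_after(1000, true_mean, decay=0.999)
after_3000 = moving_average_after(3000, true_mean, decay=0.999)

# A smaller decay tracks the batch statistics much faster.
after_1000_fast = moving_average_after(1000, true_mean, decay=0.9)

print(never_updated, after_1000, after_3000, after_1000_fast)
```

With decay=0.999 the moving mean has covered only about 63% of the distance to the batch mean after 1000 updates and about 95% after 3000, so inference-mode outputs can look off early in training even when the update ops do run. If they never run, inference keeps using the initial statistics.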
You can read more about the update ops here. I know it's a little late to answer your question, but it might help other people coming across BatchNorm problems in TensorFlow! :)
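One more thing that might be worth checking, independent of BatchNorm. This is a guess on my part, not something verified against the run in the question. The loss there takes tf.log of a softmax output; if a probability underflows to exactly 0, the log is -inf and 0 * -inf is NaN. Below is a small plain-Python sketch of that effect and of the log-sum-exp way around it (function names and logits are made up for illustration):

```python
import math

def naive_cross_entropy(logits, one_hot):
    """Softmax followed by -sum(y_ * log(y)), the same shape as the question's loss."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    probs = [e / sum(exps) for e in exps]
    total = 0.0
    for p, t in zip(probs, one_hot):
        # math.log(0.0) raises in Python, so mimic the floating-point result.
        log_p = math.log(p) if p > 0.0 else float("-inf")
        total += t * log_p  # 0.0 * -inf evaluates to nan
    return -total

def stable_cross_entropy(logits, one_hot):
    """Cross-entropy computed directly from the logits via log-sum-exp."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for z, t in zip(logits, one_hot))

extreme_logits = [1000.0, 0.0, 0.0]  # one class dominates, the others underflow
label = [1.0, 0.0, 0.0]

naive_loss = naive_cross_entropy(extreme_logits, label)
stable_loss = stable_cross_entropy(extreme_logits, label)
print(naive_loss, stable_loss)
```

In TensorFlow, tf.nn.softmax_cross_entropy_with_logits computes the loss from the pre-softmax logits, which avoids taking the log of an underflowed probability.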