Batch normalization with 3D convolutions in TensorFlow

I'm implementing a model relying on 3D convolutions (for a task that is similar to action recognition) and I want to use batch normalization (see [Ioffe & Szegedy 2015]). I could not find any tutorial focusing on 3D convs, hence I'm making a short one here which I'd like to review with you.

The code below targets TensorFlow r0.12 and explicitly instantiates variables; that is, I'm not using tf.contrib.learn, except for the tf.contrib.layers.batch_norm() function. I'm doing this both to better understand how things work under the hood and to have more implementation freedom (e.g., variable summaries).

I will get to the 3D convolution case gradually, by first writing the example for a fully-connected layer, then for a 2D convolution and finally for the 3D case. While going through the code, it would be great if you could check whether everything is done correctly: the code runs, but I'm not 100% sure about the way I apply batch normalization. I end this post with a more detailed question.

import tensorflow as tf

# This flag is used to allow/prevent batch normalization params updates
# depending on whether the model is being trained or used for prediction.
training = tf.placeholder_with_default(True, shape=())

Fully-connected (FC) case

# Input.
INPUT_SIZE = 512
u = tf.placeholder(tf.float32, shape=(None, INPUT_SIZE))

# FC params: weights only, no bias as per [Ioffe & Szegedy 2015].
FC_OUTPUT_LAYER_SIZE = 1024
w = tf.Variable(tf.truncated_normal(
    [INPUT_SIZE, FC_OUTPUT_LAYER_SIZE], dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
fc = tf.matmul(u, w)

# Batch normalization.
fc_bn = tf.contrib.layers.batch_norm(
    fc,
    center=True,
    scale=True,
    is_training=training,
    scope='fc-batch_norm')

# Activation function.
fc_bn_relu = tf.nn.relu(fc_bn)
print(fc_bn_relu)  # Tensor("Relu:0", shape=(?, 1024), dtype=float32)
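To make the role of the `training` flag concrete, here is a minimal NumPy sketch of what batch normalization computes for the FC case in each mode: mini-batch statistics while training (Alg. 1 in the paper) and accumulated moving averages at inference. It is an illustration, not the TensorFlow implementation; the `decay` and `eps` values are assumptions meant to roughly mirror the `tf.contrib.layers.batch_norm` defaults, and the input sizes are made up. One practical consequence in TensorFlow: by default the moving-average update ops are placed in `tf.GraphKeys.UPDATE_OPS`, so they only run if the training op depends on them.

```python
import numpy as np

def batch_norm_fc(x, gamma, beta, state, training, decay=0.999, eps=1e-3):
    """Batch-normalize a (batch, features) array.

    `state` holds running 'mean' and 'var' arrays (one entry per feature)
    and is updated in place when training=True.
    """
    if training:
        # Alg. 1: statistics of the current mini-batch, per feature (column).
        mean = x.mean(axis=0)
        var = x.var(axis=0)
        # Exponential moving averages, used later at inference time.
        state["mean"] = decay * state["mean"] + (1 - decay) * mean
        state["var"] = decay * state["var"] + (1 - decay) * var
    else:
        # Inference: use the accumulated population estimates instead.
        mean, var = state["mean"], state["var"]
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta  # learned per-feature scale and shift

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, scale=3.0, size=(64, 1024))  # stands in for fc = u @ w
gamma, beta = np.ones(1024), np.zeros(1024)
state = {"mean": np.zeros(1024), "var": np.ones(1024)}

y_train = batch_norm_fc(x, gamma, beta, state, training=True)
y_infer = batch_norm_fc(x, gamma, beta, state, training=False)
print(y_train.mean(axis=0)[:3], y_train.std(axis=0)[:3])  # close to 0 and 1
```

With only one update step the moving averages are still far from the batch statistics, which is why `y_infer` differs noticeably from `y_train` here.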

2D convolutional (CNN) layer case

# Input: 640x480 RGB images (whitened input, hence tf.float32).
INPUT_HEIGHT = 480
INPUT_WIDTH = 640
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: weights only, no bias as per [Ioffe & Szegedy 2015].
CNN_FILTER_HEIGHT = 3  # Space dimension.
CNN_FILTER_WIDTH = 3  # Space dimension.
CNN_FILTERS = 128
w = tf.Variable(tf.truncated_normal(
    [CNN_FILTER_HEIGHT, CNN_FILTER_WIDTH, INPUT_CHANNELS, CNN_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN_LAYER_STRIDE_VERTICAL = 1
CNN_LAYER_STRIDE_HORIZONTAL = 1
CNN_LAYER_PADDING = 'SAME'
cnn = tf.nn.conv2d(
    input=u, filter=w,
    strides=[1, CNN_LAYER_STRIDE_VERTICAL, CNN_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN_LAYER_PADDING)

# Batch normalization.
cnn_bn = tf.contrib.layers.batch_norm(
    cnn,
    data_format='NHWC',  # Matching the "cnn" tensor which has shape (?, 480, 640, 128).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn-batch_norm')

# Activation function.
cnn_bn_relu = tf.nn.relu(cnn_bn)
print(cnn_bn_relu)  # Tensor("Relu_1:0", shape=(?, 480, 640, 128), dtype=float32)
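For the NHWC convolutional output, the statistics are computed per channel over the batch and both spatial axes. This is the "convolutional property" from Sec. 3.2 of the paper: each feature map behaves like a single feature whose effective mini-batch is batch × height × width. The NumPy sketch below (training-mode statistics only, with gamma = 1, beta = 0 and small made-up sizes) checks that this is the same as flattening every pixel into its own row and applying the FC-style normalization.

```python
import numpy as np

def bn_nhwc(x, eps=1e-3):
    # One mean/variance per channel, pooled over batch, height and width.
    mean = x.mean(axis=(0, 1, 2))
    var = x.var(axis=(0, 1, 2))
    return (x - mean) / np.sqrt(var + eps)

def bn_fc(x, eps=1e-3):
    # FC case: one mean/variance per column, pooled over rows only.
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)

rng = np.random.default_rng(0)
cnn = rng.normal(size=(4, 6, 8, 16))  # (batch, height, width, channels)

y_conv = bn_nhwc(cnn)
# Treat every (image, row, column) position as one sample of a 16-feature layer.
y_flat = bn_fc(cnn.reshape(-1, 16)).reshape(cnn.shape)
print(np.allclose(y_conv, y_flat))  # True
```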

3D convolutional (CNN3D) layer case

# Input: sequence of 9 160x120 RGB images (whitened input, hence tf.float32).
INPUT_SEQ_LENGTH = 9
INPUT_HEIGHT = 120
INPUT_WIDTH = 160
INPUT_CHANNELS = 3
u = tf.placeholder(tf.float32, shape=(None, INPUT_SEQ_LENGTH, INPUT_HEIGHT, INPUT_WIDTH, INPUT_CHANNELS))

# CNN params: weights only, no bias as per [Ioffe & Szegedy 2015].
CNN3D_FILTER_LENGTH = 3  # Time dimension.
CNN3D_FILTER_HEIGHT = 3  # Space dimension.
CNN3D_FILTER_WIDTH = 3  # Space dimension.
CNN3D_FILTERS = 96
w = tf.Variable(tf.truncated_normal(
    [CNN3D_FILTER_LENGTH, CNN3D_FILTER_HEIGHT, CNN3D_FILTER_WIDTH, INPUT_CHANNELS, CNN3D_FILTERS],
    dtype=tf.float32, stddev=1e-1))

# Layer output with no activation function (yet).
CNN3D_LAYER_STRIDE_TEMPORAL = 1
CNN3D_LAYER_STRIDE_VERTICAL = 1
CNN3D_LAYER_STRIDE_HORIZONTAL = 1
CNN3D_LAYER_PADDING = 'SAME'
cnn3d = tf.nn.conv3d(
    input=u, filter=w,
    strides=[1, CNN3D_LAYER_STRIDE_TEMPORAL, CNN3D_LAYER_STRIDE_VERTICAL, CNN3D_LAYER_STRIDE_HORIZONTAL, 1],
    padding=CNN3D_LAYER_PADDING)

# Batch normalization.
cnn3d_bn = tf.contrib.layers.batch_norm(
    cnn3d,
    data_format='NHWC',  # Channels last, matching the "cnn3d" tensor of shape (?, 9, 120, 160, 96).
    center=True,
    scale=True,
    is_training=training,
    scope='cnn3d-batch_norm')

# Activation function.
cnn3d_bn_relu = tf.nn.relu(cnn3d_bn)
print(cnn3d_bn_relu)  # Tensor("Relu_2:0", shape=(?, 9, 120, 160, 96), dtype=float32)
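The same rule carries over to the 5D `cnn3d` tensor: reduce over every axis except the last (channel) one, so the time axis is pooled together with batch, height and width, and there is one mean/variance pair per feature map. Below is a rank-agnostic NumPy sketch of that rule; the small sizes stand in for (?, 9, 120, 160, 96) and are made up. As far as I can tell this matches what channels-last batch norm computes for 5D inputs, but it is an illustration rather than the TensorFlow code.

```python
import numpy as np

def batch_norm_channels_last(x, gamma, beta, eps=1e-3):
    """Training-mode batch norm for any channels-last array of rank >= 2."""
    axes = tuple(range(x.ndim - 1))  # every axis except the channel axis
    mean = x.mean(axis=axes)         # shape (channels,)
    var = x.var(axis=axes)           # shape (channels,)
    y = gamma * (x - mean) / np.sqrt(var + eps) + beta
    return y, mean, var

rng = np.random.default_rng(0)
channels = 8
# (batch, time, height, width, channels)
cnn3d = rng.normal(loc=2.0, size=(2, 3, 5, 4, channels))
gamma, beta = np.ones(channels), np.zeros(channels)

y, mean, var = batch_norm_channels_last(cnn3d, gamma, beta)
print(mean.shape, var.shape)  # (8,) (8,): one statistic per feature map
print(np.abs(y.mean(axis=(0, 1, 2, 3))).max() < 1e-6)  # True
```

The same function also covers the FC (rank 2) and 2D conv (rank 4) cases, since only the channel axis is kept.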

What I would like to make sure is whether the code above exactly implements batch normalization as described in [Ioffe & Szegedy 2015] at the end of Sec. 3.2:

For convolutional layers, we additionally want the normalization to obey the convolutional property – so that different elements of the same feature map, at different locations, are normalized in the same way. To achieve this, we jointly normalize all the activations in a minibatch, over all locations. [...] Alg. 2 is modified similarly, so that during inference the BN transform applies the same linear transformation to each activation in a given feature map.
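The last sentence of the quote (Alg. 2) means that, once training is over, the BN transform of each feature map collapses into a fixed per-channel affine map y = a·x + b, with a = γ/√(Var[x] + ε) and b = β − a·E[x], applied identically at every location. A NumPy sketch with made-up population statistics and parameters (all hypothetical, one value per feature map):

```python
import numpy as np

rng = np.random.default_rng(0)
channels = 4
# e.g. a (batch, time, height, width, channels) activation
x = rng.normal(size=(2, 3, 5, 5, channels))

# Hypothetical population statistics and learned parameters, one per feature map.
pop_mean = rng.normal(size=channels)
pop_var = rng.uniform(0.5, 2.0, size=channels)
gamma = rng.uniform(0.5, 1.5, size=channels)
beta = rng.normal(size=channels)
eps = 1e-3

# Inference-time BN written out directly ...
y_direct = gamma * (x - pop_mean) / np.sqrt(pop_var + eps) + beta

# ... and folded into a single scale and shift per feature map (Alg. 2).
a = gamma / np.sqrt(pop_var + eps)
b = beta - a * pop_mean
y_folded = a * x + b

print(np.allclose(y_direct, y_folded))  # True
```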

UPDATE: I believe the code above is also correct for the 3D conv case. In fact, when I define my model and print all the trainable variables, I see the expected number of beta and gamma variables. For instance:

Tensor("conv3a/conv3d_weights/read:0", shape=(3, 3, 3, 128, 256), dtype=float32)
Tensor("BatchNorm_2/beta/read:0", shape=(256,), dtype=float32)
Tensor("BatchNorm_2/gamma/read:0", shape=(256,), dtype=float32)

This looks OK to me since, with BN, one pair of beta and gamma is learned for each feature map (256 in total).


[Ioffe & Szegedy 2015]: Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

asked Jan 24 '17 at 14:01 by Alessio B


1 Answer

That is a great post about 3D batchnorm; it often goes unnoticed that batchnorm can be applied to any tensor of rank greater than 1. Your code is correct, but I couldn't help adding a few important notes:

  • A "standard" 2D batchnorm (which accepts a 4D tensor) can be significantly faster in TensorFlow than 3D or higher-rank batchnorm, because it supports the fused_batch_norm implementation, which runs as a single kernel:

    Fused batch norm combines the multiple operations needed to do batch normalization into a single kernel. Batch norm is an expensive process that for some models makes up a large percentage of the operation time. Using fused batch norm can result in a 12%-30% speedup.

    There is a GitHub issue requesting support for 3D filters as well, but it has had no recent activity and is now closed without a resolution.

  • Although the original paper prescribes applying batchnorm before the ReLU activation (which is what you did in the code above), there is some evidence that it may be better to apply batchnorm after the activation. Here's a comment by Francois Chollet on the Keras GitHub:

    ... I can guarantee that recent code written by Christian [Szegedy] applies relu before BN. It is still occasionally a topic of debate, though.

  • For anyone interested in applying the idea of normalization in practice, there have been recent research developments of this idea, namely weight normalization and layer normalization, which address certain disadvantages of the original batchnorm; for example, they work better for LSTMs and other recurrent networks.
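The two placements discussed in the second point differ only in the order of operations. A minimal NumPy sketch (training-mode statistics, γ = 1, β = 0, made-up channels-last input) makes the difference visible: BN followed by ReLU yields non-negative outputs, roughly half of them zero, while ReLU followed by BN yields outputs with zero mean per channel.

```python
import numpy as np

def bn(x, eps=1e-3):
    # Normalize each channel over all other axes (channels-last layout).
    axes = tuple(range(x.ndim - 1))
    return (x - x.mean(axis=axes)) / np.sqrt(x.var(axis=axes) + eps)

def relu(x):
    return np.maximum(x, 0.0)

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3, 6, 6, 8))  # (batch, time, height, width, channels)

bn_then_relu = relu(bn(x))  # original paper, and the question's code
relu_then_bn = bn(relu(x))  # the ordering described in Chollet's comment

print(bn_then_relu.min() >= 0.0)  # True: ReLU output is non-negative
print(np.abs(relu_then_bn.mean(axis=(0, 1, 2, 3))).max() < 1e-6)  # True
```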

answered Oct 10 '22 at 20:10 by Maxim