What is the significance of the `trainable` and `training` flags in `tf.layers.batch_normalization`? How are the two different during training and prediction?
Batch normalization adds two trainable parameters to each layer: the normalized output is multiplied by a "standard deviation" parameter (`gamma`) and shifted by a "mean" parameter (`beta`).
Batch normalization applies a transformation that maintains the mean output close to 0 and the output standard deviation close to 1. Importantly, batch normalization works differently during training and during inference.
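For intuition, here is a minimal NumPy sketch of that transform (the `eps` value is an assumption matching the 0.001 default of `tf.layers.batch_normalization`; the sample values reuse the image row from the example further down):

```python
import numpy as np

def batch_norm(x, gamma, beta, mean, var, eps=1e-3):
    # normalize to ~zero mean / unit variance, then scale by gamma and shift by beta
    x_hat = (x - mean) / np.sqrt(var + eps)
    return gamma * x_hat + beta

x = np.array([4., 5., 9., 0.])
# with mean=0 and var=1 this reduces to roughly gamma * x + beta = 2x + 1
print(batch_norm(x, gamma=2.0, beta=1.0, mean=0.0, var=1.0))  # ~[ 9. 11. 19.  1.]
```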
Batch normalization was introduced to address a problem called internal covariate shift: it keeps the distribution of the data flowing between the intermediate layers of the network stable, which means you can use a higher learning rate. It also has a regularizing effect, which means you can often remove dropout.
In conclusion, normalization layers in a model often help to speed up and stabilize training. If training with large batches isn't an issue and the network doesn't have any recurrent connections, batch normalization is a reasonable choice.
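For instance, a minimal sketch of dropping a `BatchNormalization` layer into a small feed-forward network (using the `tf.keras` API; the layer sizes and loss are arbitrary choices for illustration):

```python
import tensorflow as tf

model = tf.keras.Sequential([
    tf.keras.layers.Dense(64, input_shape=(32,)),
    tf.keras.layers.BatchNormalization(),   # normalize the Dense outputs
    tf.keras.layers.Activation('relu'),
    tf.keras.layers.Dense(10, activation='softmax'),
])
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
```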
The batch norm has two phases (sketched in code right after this list):
1. Training:
- Normalize the layer activations using the mean and variance of the current batch, then scale with `gamma` and shift with `beta`.
(`training` should be `True`.)
- Update the `moving_avg` and `moving_var` statistics.
(`trainable` should be `True`.)
2. Inference:
- Normalize the layer activations using `moving_avg`, `moving_var`, `gamma` and `beta`.
(`training` should be `False`.)
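Here is a rough NumPy sketch of what the two phases compute; the `momentum` and `eps` defaults are assumptions matching the `tf.layers.batch_normalization` defaults of 0.99 and 0.001:

```python
import numpy as np

def bn_training(x, gamma, beta, moving_avg, moving_var,
                momentum=0.99, eps=1e-3):
    # training: normalize with the statistics of the *current batch*...
    mu, var = x.mean(axis=0), x.var(axis=0)
    out = gamma * (x - mu) / np.sqrt(var + eps) + beta
    # ...and update the moving statistics as a side effect
    new_avg = momentum * moving_avg + (1 - momentum) * mu
    new_var = momentum * moving_var + (1 - momentum) * var
    return out, new_avg, new_var

def bn_inference(x, gamma, beta, moving_avg, moving_var, eps=1e-3):
    # inference: normalize with the stored moving statistics
    return gamma * (x - moving_avg) / np.sqrt(moving_var + eps) + beta
```

During training the moving statistics only change as a side effect of each step, which is what the `update_ops` in the example below are responsible for.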
Example code to illustrate a few cases:
```python
import numpy as np
import tensorflow as tf  # TF 1.x

# random image
img = np.random.randint(0, 10, (2, 2, 4)).astype(np.float32)

# batch norm params initialized
beta = np.ones((4)).astype(np.float32) * 1        # all ones
gamma = np.ones((4)).astype(np.float32) * 2       # all twos
moving_mean = np.zeros((4)).astype(np.float32)    # all zeros
moving_var = np.ones((4)).astype(np.float32)      # all ones

# placeholder for the input image
_input = tf.placeholder(tf.float32, shape=(1, 2, 2, 4), name='input')

# batch norm
out = tf.layers.batch_normalization(
    _input,
    beta_initializer=tf.constant_initializer(beta),
    gamma_initializer=tf.constant_initializer(gamma),
    moving_mean_initializer=tf.constant_initializer(moving_mean),
    moving_variance_initializer=tf.constant_initializer(moving_var),
    training=False, trainable=False)

update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
init_op = tf.global_variables_initializer()

## 2. Run the graph in a session
with tf.Session() as sess:
    # init the variables
    sess.run(init_op)
    for i in range(2):
        ops, o = sess.run([update_ops, out],
                          feed_dict={_input: np.expand_dims(img, 0)})
        print('beta', sess.run('batch_normalization/beta:0'))
        print('gamma', sess.run('batch_normalization/gamma:0'))
        print('moving_avg', sess.run('batch_normalization/moving_mean:0'))
        print('moving_variance', sess.run('batch_normalization/moving_variance:0'))
        print('out', np.round(o))
        print('')
```
When `training=False` and `trainable=False`:

img = [[[4., 5., 9., 0.]...
out = [[ 9. 11. 19. 1.]...

The activation is only scaled and shifted using `gamma` and `beta`. Since `moving_mean` is all zeros and `moving_var` is all ones, the transform reduces to roughly `gamma * x + beta = 2x + 1` (4 → 9, 5 → 11, 9 → 19, 0 → 1 after rounding).
When `training=True` and `trainable=False`:

out = [[ 2. 2. 3. -1.] ...

The activation is normalized using the mean and variance of the current batch (not the moving averages), then scaled with `gamma` and shifted with `beta`. The moving averages are not updated.
When `training=True` and `trainable=True`:

The output is the same as above, but `moving_avg` and `moving_var` now get updated to new values:

moving_avg [0.03249997 0.03499997 0.06499994 0.02749997]
moving_variance [1.0791667 1.1266665 1.0999999 1.0925]
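This is also why, in a real TF 1.x training loop, the standard idiom is to attach these update ops to the train step so that the moving statistics are refreshed on every iteration. A fragment of that pattern (the `loss` tensor stands in for your own graph):

```python
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    # the moving_mean / moving_variance updates run before each train step
    train_op = tf.train.AdamOptimizer(1e-3).minimize(loss)
```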
This is quite complicated, and in TF 2.0 the behavior changed; see the note in the source:
https://github.com/tensorflow/tensorflow/blob/095272a4dd259e8acd3bc18e9eb5225e7a4d7476/tensorflow/python/keras/layers/normalization_v2.py#L26
About setting `layer.trainable = False` on a `BatchNormalization` layer:

The meaning of setting `layer.trainable = False` is to freeze the layer, i.e. its internal state will not change during training: its trainable weights will not be updated during `fit()` or `train_on_batch()`, and its state updates will not be run. Usually, this does not necessarily mean that the layer is run in inference mode (which is normally controlled by the `training` argument that can be passed when calling a layer). "Frozen state" and "inference mode" are two separate concepts.

However, in the case of the `BatchNormalization` layer, setting `trainable = False` on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch). This behavior has been introduced in TensorFlow 2.0, in order to enable `layer.trainable = False` to produce the most commonly expected behavior in the convnet fine-tuning use case. Note that:

- This behavior only occurs as of TensorFlow 2.0. In 1.*, setting `layer.trainable = False` would freeze the layer but would not switch it to inference mode.
- Setting `trainable` on a model containing other layers will recursively set the `trainable` value of all inner layers.
- If the value of the `trainable` attribute is changed after calling `compile()` on a model, the new value doesn't take effect for this model until `compile()` is called again.
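As a sketch of that convnet fine-tuning use case in TF 2.x (the base network, input shape, and head are arbitrary choices for illustration):

```python
import tensorflow as tf

base = tf.keras.applications.MobileNetV2(
    input_shape=(160, 160, 3), include_top=False, weights='imagenet')
base.trainable = False  # freezes weights; in TF 2.x the BN layers also run in inference mode

inputs = tf.keras.Input(shape=(160, 160, 3))
x = base(inputs, training=False)  # keep BN in inference mode even if base is later unfrozen
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(1)(x)
model = tf.keras.Model(inputs, outputs)
model.compile(optimizer=tf.keras.optimizers.Adam(1e-4),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True))
```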