Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d?

I have a tiny sample CNN implemented in both Keras and PyTorch. When I print the summary of both networks, the total number of trainable parameters is the same, but the total number of parameters and the number of parameters for Batch Normalization don't match.

Here is the CNN implementation in Keras:

from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense

inputs = Input(shape=(64, 64, 1))  # Channels last: (NHWC)

# input_shape is taken from Input above, so it doesn't need to be repeated on Conv2D
model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu')(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation='relu')(model)
head_root = Dense(10, activation='softmax')(dense)
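
To print the summary below, the graph first needs to be wrapped in a Model; a minimal sketch, assuming tf.keras (the keras_model name is just for illustration):

from tensorflow.keras.models import Model

keras_model = Model(inputs=inputs, outputs=head_root)
keras_model.summary()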

And the summary printed for the above model is:

Model: "model_8"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_9 (InputLayer)         (None, 64, 64, 1)         0         
_________________________________________________________________
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
_________________________________________________________________
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
_________________________________________________________________
flatten_3 (Flatten)          (None, 131072)            0         
_________________________________________________________________
dense_11 (Dense)             (None, 100)               13107300  
_________________________________________________________________
dense_12 (Dense)             (None, 10)                1010      
=================================================================
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64
_________________________________________________________________

Here's the implementation of the same model architecture in PyTorch:

import torch.nn as nn

# Image format: channels first (NCHW) in PyTorch
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
            nn.ReLU(True),
            nn.BatchNorm2d(num_features=32),
        )
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=131072, out_features=100)  # 32 * 64 * 64
        self.fc2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        output = self.layer1(x)
        output = self.flatten(output)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

And the following is the summary output for the above model:

----------------------------------------------------------------
        Layer (type)               Output Shape         Param #
================================================================
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
================================================================
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
----------------------------------------------------------------
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02
----------------------------------------------------------------
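
A table in this format is what the third-party torchsummary package prints; a hedged sketch for reproducing it, assuming that package is installed (pip install torchsummary):

from torchsummary import summary

model = CustomModel()
summary(model, input_size=(1, 64, 64), device='cpu')  # (C, H, W); the batch dim is added internally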

As you can see in the results above, the Batch Normalization layer in Keras has twice as many parameters as PyTorch's BatchNorm2d. So what's the difference between these two CNN architectures? If they are equivalent, what am I missing here?

asked Feb 05 '20 by Kaushal28


1 Answer

Keras treats as parameters (weights) many things that will be "saved/loaded" in the layer.

While both implementations naturally have the accumulated "mean" and "variance" of the batches, these values are not trainable with backpropagation.

Nevertheless, these values are updated every batch, and Keras treats them as non-trainable weights, while PyTorch simply hides them. The term "non-trainable" here means "not trainable by backpropagation", but doesn't mean the values are frozen.
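
You can see this directly in PyTorch: the running statistics live in buffers, not parameters. A minimal sketch (the standalone bn layer below is just for illustration, with the same 32 channels as in the model above):

import torch.nn as nn

bn = nn.BatchNorm2d(num_features=32)

# Trainable parameters: weight (gamma) and bias (beta), 32 each -> 64 total,
# exactly the 64 the PyTorch summary reports for BatchNorm2d-3
print(sum(p.numel() for p in bn.parameters()))  # 64

# The accumulated statistics are registered as buffers, so they are
# saved in the state_dict but never counted as parameters
for name, buf in bn.named_buffers():
    print(name, tuple(buf.shape))
# running_mean (32,)
# running_var (32,)
# num_batches_tracked ()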

In total, there are 4 groups of "weights" for a BatchNormalization layer, considering the selected axis (default = -1, size = 32 for your layer):

  • scale (32) - trainable
  • offset (32) - trainable
  • accumulated means (32) - non-trainable, but updated every batch
  • accumulated variance (32) - non-trainable, but updated every batch
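
That is where the factor of 2 comes from: Keras counts all 4 groups, 4 × 32 = 128 parameters (64 trainable + 64 non-trainable), while PyTorch counts only the 2 trainable groups, 2 × 32 = 64.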

The advantage of having it like this in Keras is that when you save the layer, the mean and variance values are saved automatically, the same way all the other weights in the layer are. And when you load the layer, these weights are loaded together.
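
You can verify the 4 groups from Keras itself; a minimal sketch, assuming tf.keras (the standalone layer below is hypothetical, built with the same 32 channels as above):

from tensorflow.keras.layers import BatchNormalization

bn_layer = BatchNormalization(momentum=0.15, axis=-1)
bn_layer.build(input_shape=(None, 64, 64, 32))  # build the layer so its weights exist

# All 4 weight groups: gamma, beta, moving_mean, moving_variance
for w in bn_layer.weights:
    print(w.name, w.shape)

print(len(bn_layer.trainable_weights))      # 2 -> gamma, beta (64 params)
print(len(bn_layer.non_trainable_weights))  # 2 -> moving_mean, moving_variance (64 params)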

answered Sep 30 '22 by Daniel Möller