Difference between Keras' BatchNormalization and PyTorch's BatchNorm2d?

I have a tiny sample CNN implemented in both Keras and PyTorch. When I print the summary of both networks, the number of trainable parameters is the same, but the total number of parameters and the number of parameters for batch normalization don't match.

Here is the CNN implementation in Keras:

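from tensorflow.keras.layers import Input, Conv2D, BatchNormalization, Flatten, Dense
from tensorflow.keras.models import Model

IMG_SIZE = 64

inputs = Input(shape=(IMG_SIZE, IMG_SIZE, 1))  # Channels last (NHWC)

model = Conv2D(filters=32, kernel_size=(3, 3), padding='SAME', activation='relu', input_shape=(IMG_SIZE, IMG_SIZE, 1))(inputs)
model = BatchNormalization(momentum=0.15, axis=-1)(model)
model = Flatten()(model)

dense = Dense(100, activation="relu")(model)
head_root = Dense(10, activation='softmax')(dense)

# Wrap the graph in a Model so that summary() can be printed
Model(inputs=inputs, outputs=head_root).summary()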

And the summary printed for the above model is:

Model: "model_8"
Layer (type)                 Output Shape              Param #   
input_9 (InputLayer)         (None, 64, 64, 1)         0         
conv2d_10 (Conv2D)           (None, 64, 64, 32)        320       
batch_normalization_2 (Batch (None, 64, 64, 32)        128       
flatten_3 (Flatten)          (None, 131072)            0         
dense_11 (Dense)             (None, 100)               13107300  
dense_12 (Dense)             (None, 10)                1010      
Total params: 13,108,758
Trainable params: 13,108,694
Non-trainable params: 64

Here's the implementation of the same model architecture in PyTorch:

import torch.nn as nn

# Image format: channels first (NCHW) in PyTorch
class CustomModel(nn.Module):
    def __init__(self):
        super(CustomModel, self).__init__()
        # Conv -> ReLU -> BatchNorm, matching the layer order in the summary
        self.layer1 = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(3, 3), padding=1),
            nn.ReLU(),
            nn.BatchNorm2d(num_features=32),
        )
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(in_features=131072, out_features=100)
        self.fc2 = nn.Linear(in_features=100, out_features=10)

    def forward(self, x):
        output = self.layer1(x)
        output = self.flatten(output)
        output = self.fc1(output)
        output = self.fc2(output)
        return output

And the following is the summary output for the above model:

        Layer (type)               Output Shape         Param #
            Conv2d-1           [-1, 32, 64, 64]             320
              ReLU-2           [-1, 32, 64, 64]               0
       BatchNorm2d-3           [-1, 32, 64, 64]              64
           Flatten-4               [-1, 131072]               0
            Linear-5                  [-1, 100]      13,107,300
            Linear-6                   [-1, 10]           1,010
Total params: 13,108,694
Trainable params: 13,108,694
Non-trainable params: 0
Input size (MB): 0.02
Forward/backward pass size (MB): 4.00
Params size (MB): 50.01
Estimated Total Size (MB): 54.02

As you can see in the results above, batch normalization in Keras has more parameters than in PyTorch (2x, to be exact). So what is the difference between the two CNN architectures? If they are equivalent, what am I missing here?

1 Answer

Keras counts as parameters (weights) many values that are saved and loaded with the layer, whether or not they are trained.

While both implementations naturally have the accumulated "mean" and "variance" of the batches, these values are not trainable with backpropagation.

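Nevertheless, these values are updated every batch during training. Keras treats them as non-trainable weights, while PyTorch keeps them as buffers (running_mean and running_var) that are not part of the module's parameters, so the summary does not count them. The term "non-trainable" here means "not trainable by backpropagation"; it doesn't mean the values are frozen.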
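
To see this on the PyTorch side, here is a minimal sketch (not the asker's model; the batch of random inputs and the variable names are assumptions made for illustration). It lists the learnable parameters and the buffers of a standalone BatchNorm2d layer, then checks that a forward pass in training mode changes the running mean without any backward pass:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(num_features=32)

# Learnable tensors, updated by backpropagation: weight (scale) and bias (offset).
trainable = {name: p.numel() for name, p in bn.named_parameters()}

# Buffers: stored on the module, but not returned by parameters().
buffers = {name: b.numel() for name, b in bn.named_buffers()}

print(trainable)  # {'weight': 32, 'bias': 32}
print(buffers)    # {'running_mean': 32, 'running_var': 32, 'num_batches_tracked': 1}

# A forward pass in training mode updates the running statistics,
# with no loss, no backward pass and no optimizer step involved.
before = bn.running_mean.clone()
bn.train()
with torch.no_grad():
    bn(torch.randn(8, 32, 4, 4) + 5.0)  # hypothetical batch with mean close to 5

running_mean_changed = not torch.equal(before, bn.running_mean)
print(running_mean_changed)  # True
```

The 64 values in running_mean and running_var correspond to the 64 non-trainable parameters that Keras reports.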

In total there are 4 groups of "weights" in a BatchNormalization layer. Considering the selected axis (default = -1, size = 32 for your layer):

  • scale (32) - trainable
  • offset (32) - trainable
  • accumulated means (32) - non-trainable, but updated every batch
  • accumulated variances (32) - non-trainable, but updated every batch
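
The arithmetic behind the two summaries can be checked with plain Python. This is only a sketch of the counting, using the layer sizes from the question; the variable names are made up for illustration:

```python
channels = 32

conv = 3 * 3 * 1 * channels + channels  # 3x3 kernels on 1 input channel, plus biases
dense1 = 131072 * 100 + 100             # weights plus biases
dense2 = 100 * 10 + 10

bn_trainable = 2 * channels      # scale + offset
bn_non_trainable = 2 * channels  # accumulated means + accumulated variances

# Keras counts all four groups; the PyTorch summary counts only the trainable ones.
keras_total = conv + bn_trainable + bn_non_trainable + dense1 + dense2
pytorch_total = conv + bn_trainable + dense1 + dense2

print(keras_total)                  # 13108758
print(pytorch_total)                # 13108694
print(keras_total - pytorch_total)  # 64
```

The difference of 64 is exactly the "Non-trainable params" line in the Keras summary.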

The advantage of having it like this in Keras is that when you save the layer, you also save the mean and variance values the same way you save all other weights in the layer automatically. And when you load the layer, these weights are loaded together.
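
For comparison, PyTorch also saves these values, just not as parameters: they travel in the module's state_dict. A small sketch, assuming a standalone layer and a made-up input batch rather than the full model from the question:

```python
import torch
import torch.nn as nn

bn = nn.BatchNorm2d(32)
bn.train()
with torch.no_grad():
    bn(torch.randn(8, 32, 4, 4) + 5.0)  # hypothetical batch, moves the running stats

# state_dict() holds the parameters and the buffers together.
keys = list(bn.state_dict().keys())
print(keys)
# ['weight', 'bias', 'running_mean', 'running_var', 'num_batches_tracked']

# Loading the state into a fresh layer restores the running statistics too.
restored = nn.BatchNorm2d(32)
restored.load_state_dict(bn.state_dict())
stats_restored = torch.equal(restored.running_mean, bn.running_mean) and torch.equal(
    restored.running_var, bn.running_var
)
print(stats_restored)  # True
```

So the two frameworks store the same information; they differ in whether the running statistics are reported in the parameter count.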

