torch.nn has the classes BatchNorm1d, BatchNorm2d, and BatchNorm3d, but it doesn't have a fully connected BatchNorm class. What is the standard way of doing Batch Norm on fully connected layers in PyTorch?
In this paper, two models derived from VGG16 are trained with and without Batch Normalization to compare their performance. The model using Batch Normalization clearly does better, reaching a lower loss value and a noticeably higher accuracy.
Batch normalization (BN) is known to improve model performance, mitigate internal covariate shift, and provide a small regularization effect. These properties, together with empirical studies demonstrating its effectiveness, helped solidify practitioners' preference for BN over dropout.
Unlike plain input normalization, which forces all values to zero mean and unit variance, Batch Norm allows the normalized values to be shifted (to a different mean) and scaled (to a different variance). It does this by multiplying the normalized values by a learnable factor, gamma, and adding a learnable factor, beta, to them.
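As a minimal sketch (the tensor sizes below are made up for illustration), the training-time computation can be reproduced by hand and checked against nn.BatchNorm1d:

import torch
import torch.nn as nn

batch, features = 4, 8              # arbitrary sizes for illustration
x = torch.randn(batch, features)

bn = nn.BatchNorm1d(features)
bn.train()
out = bn(x)

# What BatchNorm1d does in training mode: normalize each feature over the
# batch to zero mean / unit variance, then scale by gamma (bn.weight) and
# shift by beta (bn.bias).
mean = x.mean(dim=0)
var = x.var(dim=0, unbiased=False)
x_hat = (x - mean) / torch.sqrt(var + bn.eps)
manual = bn.weight * x_hat + bn.bias

print(torch.allclose(out, manual, atol=1e-6))  # True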
BatchNorm1d normally comes before the ReLU, and the bias of the preceding Linear layer becomes redundant (BatchNorm's beta takes its place), so:
import torch.nn as nn
import torch.nn.functional as F

class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        # bias=False: the BatchNorm layer that follows has its own shift (beta),
        # so a bias in the Linear layer would be redundant.
        self.linear1 = nn.Linear(num_inputs, hidden_size1, bias=False)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2, bias=False)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.bn1(self.linear1(x)))  # Linear -> BN -> ReLU
        x = F.relu(self.bn2(self.linear2(x)))
        out = self.linear3(x)
        return out
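For example (the input size, action count, and batch size below are arbitrary), the module is used with a rank-2 batch of inputs:

import torch

policy = Policy(num_inputs=10, action_space=4)   # hypothetical sizes
policy.train()                                   # BN uses batch statistics
logits = policy(torch.randn(32, 10))             # batch of 32 samples
print(logits.shape)                              # torch.Size([32, 4])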
Ok, I figured it out. BatchNorm1d can also handle rank-2 tensors, thus it is possible to use BatchNorm1d for the normal fully-connected case. So for example:
import torch.nn as nn
import torch.nn.functional as F

class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        self.linear1 = nn.Linear(num_inputs, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        # BatchNorm1d accepts (N, C) input, so it works directly on
        # fully-connected activations.
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        x = self.bn1(F.relu(self.linear1(x)))  # here BN is applied after the ReLU
        x = self.bn2(F.relu(self.linear2(x)))
        out = self.linear3(x)
        return out
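One caveat with either variant: in training mode BatchNorm1d computes statistics over the batch dimension, so a single-sample batch fails; switch to eval() to use the running statistics instead. A small sketch (sizes again arbitrary):

import torch

policy = Policy(num_inputs=10, action_space=4)   # hypothetical sizes
single = torch.randn(1, 10)

# Calling policy(single) in train() mode would raise
# "Expected more than 1 value per channel when training".
policy.eval()                                    # use running mean/var instead
with torch.no_grad():
    out = policy(single)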