 

How to do fully connected batch norm in PyTorch?

torch.nn has the classes BatchNorm1d, BatchNorm2d, and BatchNorm3d, but none of them is described as a fully connected BatchNorm class. What is the standard way of applying batch norm to fully connected layers in PyTorch?

asked Nov 09 '17 by patapouf_ai

People also ask

Does VGG16 use batch normalization?

In one study, two models derived from VGG16 were trained with and without Batch Normalization to compare their performance. The model using Batch Normalization gives a clearly better result, both in loss value and in model accuracy, and reaches a very high accuracy rate.

Is batch norm better than dropout?

Batch normalization (BN) is known to improve model performance, mitigate internal covariate shift, and apply a small regularization effect. These properties, together with empirical studies demonstrating its effectiveness, have helped solidify people's preference for using BN over dropout.

How is batch norm calculated?

Unlike the input layer, which requires all normalized values to have zero mean and unit variance, Batch Norm allows its values to be shifted (to a different mean) and scaled (to a different variance). It does this by multiplying the normalized values by a factor, gamma, and adding to it a factor, beta.
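
As a concrete sketch of that calculation (a minimal PyTorch illustration with arbitrary shapes, matching what nn.BatchNorm1d does in training mode with its default initialisation):

import torch

x = torch.randn(32, 64)                      # a batch of 32 feature vectors
eps = 1e-5

mean = x.mean(dim=0)                         # per-feature batch mean
var = x.var(dim=0, unbiased=False)           # per-feature batch variance
x_hat = (x - mean) / torch.sqrt(var + eps)   # zero mean, unit variance

gamma = torch.ones(64)                       # learnable scale, initialised to 1
beta = torch.zeros(64)                       # learnable shift, initialised to 0
y = gamma * x_hat + beta                     # shifted and scaled output

# agrees with nn.BatchNorm1d in training mode
bn = torch.nn.BatchNorm1d(64)
print(torch.allclose(y, bn(x), atol=1e-5))   # True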


2 Answers

BatchNorm1d normally comes before the ReLU, and the bias of the preceding Linear layer is redundant, so:

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        # Linear -> BatchNorm -> ReLU; the Linear bias is dropped because
        # BatchNorm's mean subtraction would cancel it anyway
        self.linear1 = nn.Linear(num_inputs, hidden_size1, bias=False)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2, bias=False)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        x = F.relu(self.bn1(self.linear1(x)))
        x = F.relu(self.bn2(self.linear2(x)))
        out = self.linear3(x)

        return out
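
To see why the bias is redundant here: in training mode, BatchNorm1d subtracts the per-feature batch mean, so any constant bias added by the preceding Linear layer is cancelled. A quick sketch to check this (the layer sizes are arbitrary, chosen only for illustration):

import torch
import torch.nn as nn

torch.manual_seed(0)
x = torch.randn(8, 4)

lin_bias = nn.Linear(4, 3, bias=True)
lin_nobias = nn.Linear(4, 3, bias=False)
with torch.no_grad():
    lin_nobias.weight.copy_(lin_bias.weight)  # same weights, only the bias differs

bn = nn.BatchNorm1d(3)  # training mode by default, so batch statistics are used

print(torch.allclose(bn(lin_bias(x)), bn(lin_nobias(x)), atol=1e-6))  # True

Both outputs match, so dropping the bias saves parameters without changing the result.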
answered Oct 13 '22 by cdude


OK, I figured it out. BatchNorm1d can also handle rank-2 tensors of shape (batch, features), so it is possible to use BatchNorm1d for the normal fully-connected case.

So for example:

import torch.nn as nn
import torch.nn.functional as F


class Policy(nn.Module):
    def __init__(self, num_inputs, action_space, hidden_size1=256, hidden_size2=128):
        super(Policy, self).__init__()
        self.action_space = action_space
        num_outputs = action_space

        self.linear1 = nn.Linear(num_inputs, hidden_size1)
        self.linear2 = nn.Linear(hidden_size1, hidden_size2)
        self.linear3 = nn.Linear(hidden_size2, num_outputs)
        # BatchNorm1d normalizes each feature over the batch dimension
        # of a rank-2 (batch, features) tensor
        self.bn1 = nn.BatchNorm1d(hidden_size1)
        self.bn2 = nn.BatchNorm1d(hidden_size2)

    def forward(self, inputs):
        x = inputs
        x = self.bn1(F.relu(self.linear1(x)))
        x = self.bn2(F.relu(self.linear2(x)))
        out = self.linear3(x)

        return out
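
A quick usage sketch of the class above (the sizes and batch shapes are just illustrative assumptions): BatchNorm1d expects a rank-2 (batch_size, num_features) input here, so train with batches of more than one sample, and call .eval() before single-sample inference so the running statistics are used:

import torch

policy = Policy(num_inputs=10, action_space=4)

policy.train()
batch = torch.randn(32, 10)   # rank-2 input: (batch_size, num_inputs)
out = policy(batch)           # -> shape (32, 4)

policy.eval()                 # switch to running statistics
single = torch.randn(1, 10)   # batch size 1 now works at inference time
out = policy(single)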
answered Oct 13 '22 by patapouf_ai