
Using nn.Linear() and nn.BatchNorm1d() together

Tags:

pytorch

I don't understand how BatchNorm1d works when the data is 3D, (batch size, H, W).

Example

  • Input size: (2,50,70)
  • Layer: nn.Linear(70,20)
  • Output size: (2,50,20)

If I then include a batch normalisation layer it requires num_features=50:

  • BN : nn.BatchNorm1d(50)

and I don't understand why it isn't 20:

  • BN : nn.BatchNorm1d(20)

Example 1)

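import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn11 = nn.BatchNorm1d(50)  # 50 = size of dim 1 of the linear output
        self.fc11 = nn.Linear(70, 20)

    def forward(self, inputs):
        out = self.fc11(inputs)  # (2, 50, 70) -> (2, 50, 20)
        out = torch.relu(self.bn11(out))
        return out

model = Net()
inputs = torch.randn(2, 50, 70)  # random values; torch.Tensor(2, 50, 70) is uninitialised
outputs = model(inputs)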

Example 2)

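import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.bn11 = nn.BatchNorm1d(20)  # 20 = number of output nodes of the linear layer
        self.fc11 = nn.Linear(70, 20)

    def forward(self, inputs):
        out = self.fc11(inputs)  # (2, 50, 70) -> (2, 50, 20)
        out = torch.relu(self.bn11(out))  # fails here: dim 1 has size 50, not 20
        return out

model = Net()
inputs = torch.randn(2, 50, 70)  # random values; torch.Tensor(2, 50, 70) is uninitialised
outputs = model(inputs)

  • Example 1 works.
  • Example 2 throws the error:
    • RuntimeError: running_mean should contain 50 elements not 20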

2D example:

  • Input size: (2,70)
  • Layer: nn.Linear(70,20)
  • BN: nn.BatchNorm1d(20)

I thought the 20 in the BN layer was there because the linear layer outputs 20 nodes and each one requires a running mean/std for its incoming values.

Why, in the 3D case, doesn't the BN layer have 20 features when the linear layer has 20 output nodes?

Niamh asked May 30 '26 00:05


1 Answer

The answer is in the torch.nn.Linear documentation.

It takes input of shape (N, *, I) and returns (N, *, O), where I is the input dimension, O is the output dimension, and * stands for any number of dimensions in between.
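A quick shape check illustrates this. It is only a sketch: the variable names are made up for illustration, and the 4D input is an extra case that is not in the question.

```python
import torch
import torch.nn as nn

fc = nn.Linear(70, 20)  # I = 70, O = 20

# nn.Linear only transforms the last dimension; everything before it is kept.
shape_2d = tuple(fc(torch.randn(2, 70)).shape)
shape_3d = tuple(fc(torch.randn(2, 50, 70)).shape)
shape_4d = tuple(fc(torch.randn(2, 3, 50, 70)).shape)  # extra case for illustration

print(shape_2d)  # (2, 20)
print(shape_3d)  # (2, 50, 20)
print(shape_4d)  # (2, 3, 50, 20)
```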

If you pass torch.Tensor(2,50,70) into nn.Linear(70,20), you get an output of shape (2, 50, 20). BatchNorm1d treats a 3D input as (N, C, L) and keeps one running mean and variance per channel C, that is, per index of the first non-batch dimension, so num_features has to be 50. That's the reason behind your error.
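The sketch below reproduces both cases on the shapes from the question. The last part is an assumption about what you may actually want, not something the question states: if the goal is one set of statistics per linear output node (20 features), one option is to move that dimension into position 1 with transpose before the BN layer and move it back afterwards. The variable names are hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fc = nn.Linear(70, 20)
out = fc(torch.randn(2, 50, 70))  # shape (2, 50, 20)

# BatchNorm1d reads a 3D input as (N, C, L), so C = 50 here.
bn50 = nn.BatchNorm1d(50)
y50 = bn50(out)
# In training mode each of the 50 channels is normalised over dims 0 and 2.
per_channel_mean = y50.mean(dim=(0, 2))

# BatchNorm1d(20) does not match dim 1 (size 50) and raises a RuntimeError.
bn20 = nn.BatchNorm1d(20)
try:
    bn20(out)
    mismatch_raised = False
except RuntimeError:
    mismatch_raised = True

# Assumed workaround: put the 20 features in dim 1, normalise, swap back.
y20 = bn20(out.transpose(1, 2)).transpose(1, 2)  # shape (2, 50, 20)
per_feature_mean = y20.mean(dim=(0, 1))

print(tuple(y50.shape), mismatch_raised, tuple(y20.shape))
```

With the transpose, each of the 20 output nodes is normalised over the batch and the 50 positions, which matches the intuition from the 2D case.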

Szymon Maszke answered Jun 01 '26 20:06


