 

Understanding PyTorch CNN Channels

I'm a bit confused about how CNNs and channels work. Specifically, why aren't these two implementations equal? Isn't the number of output channels just the number of filters being applied?

    self.conv1 = nn.Conv2d(1, 10, kernel_size=(3, self.embeds_size))
    self.conv2 = nn.ModuleList([nn.Conv2d(1, 1, kernel_size=(3, self.embeds_size)) for f in range(10)])
    ...


    conv1s = self.conv1(x)
    conv2s = [conv(x) for conv in self.conv2]
    conv2s = torch.stack(conv2s, 1).squeeze(2)
    print(torch.equal(conv1s, conv2s))
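A minimal sketch confirming that the two versions at least produce tensors of the same shape. It assumes a hypothetical embedding size of 50 and a hypothetical input of shape (batch=4, channels=1, seq_len=20, embeds_size). The weight shapes also show that the 10-channel layer simply stores 10 single-input filters along its first dimension:

```python
import torch
import torch.nn as nn

embeds_size = 50  # hypothetical embedding dimension
x = torch.randn(4, 1, 20, embeds_size)  # hypothetical input: (batch, 1, seq_len, embeds_size)

conv1 = nn.Conv2d(1, 10, kernel_size=(3, embeds_size))
conv2 = nn.ModuleList([nn.Conv2d(1, 1, kernel_size=(3, embeds_size)) for _ in range(10)])

# One layer with 10 output channels holds 10 filters stacked along dim 0
print(conv1.weight.shape)     # torch.Size([10, 1, 3, 50])
print(conv2[0].weight.shape)  # torch.Size([1, 1, 3, 50])

conv1s = conv1(x)  # (4, 10, 18, 1): seq_len 20 - kernel 3 + 1 = 18, width 50 - 50 + 1 = 1
# each conv(x) is (4, 1, 18, 1); stack on dim 1 -> (4, 10, 1, 18, 1); squeeze(2) -> (4, 10, 18, 1)
conv2s = torch.stack([conv(x) for conv in conv2], 1).squeeze(2)
print(conv1s.shape, conv2s.shape)
```

So the shapes line up; any difference has to come from the values.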
asked Jan 01 '26 00:01 by Matt



1 Answer

Check the state dicts of the different modules. Unless you're doing something you didn't show us, PyTorch initializes each module's weights and biases randomly and independently, so conv1 and the ten modules in conv2 start with different parameters. Specifically, try this:

    print(self.conv1.state_dict()["weight"][0])
    print(self.conv2[0].state_dict()["weight"][0])

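They will be different. Your intuition about channels is right: a Conv2d with 10 output channels is equivalent to 10 independent single-output filters stacked along the channel dimension (its weight has shape (10, 1, 3, embeds_size)). The outputs differ only because the two setups hold different parameters.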
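A minimal sketch of the equivalence, assuming the same hypothetical sizes as above (embeds_size = 50, input of shape (4, 1, 20, 50)): copy filter i and bias i of conv1 into the i-th single-filter module, and the outputs then match. The comparison uses torch.allclose rather than torch.equal, because the two code paths may sum in a different floating-point order and torch.equal requires bit-exact equality:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
embeds_size = 50  # hypothetical embedding dimension
x = torch.randn(4, 1, 20, embeds_size)  # hypothetical input batch

conv1 = nn.Conv2d(1, 10, kernel_size=(3, embeds_size))
conv2 = nn.ModuleList([nn.Conv2d(1, 1, kernel_size=(3, embeds_size)) for _ in range(10)])

# Give the i-th single-filter module the i-th filter (and bias) of conv1
with torch.no_grad():
    for i, conv in enumerate(conv2):
        conv.weight.copy_(conv1.weight[i : i + 1])  # shape (1, 1, 3, embeds_size)
        conv.bias.copy_(conv1.bias[i : i + 1])      # shape (1,)

conv1s = conv1(x)
conv2s = torch.stack([conv(x) for conv in conv2], 1).squeeze(2)

# Compare with a small tolerance instead of exact equality
print(torch.allclose(conv1s, conv2s, atol=1e-5))
```

Without the copy loop, the same comparison fails, which is exactly the behavior seen in the question.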

answered Jan 03 '26 19:01 by Jens Petersen



