How can I specify the flatten layer input size after many conv layers in PyTorch?

Here is my problem: I am running a small test on the CIFAR10 dataset. How can I specify the flatten layer input size in PyTorch? In the code below the input size is 16*5*5, but I don't know how to calculate it, and I would like to get the input size through some function. Can someone write a simple function in this Net class that solves this?

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # HERE, the input size is 16*5*5, but I don't know how to get it.
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        x = x.view(x.size()[0], -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
asked Sep 24 '18 by gaoquan liang

3 Answers

There is no Flatten layer in PyTorch by default (at the time of writing). You can create one as a small class like the one below. Cheers

class Flatten(nn.Module):
    def forward(self, input):
        # collapse everything except the batch dimension: (N, C, H, W) -> (N, C*H*W)
        return input.view(input.size(0), -1)


class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten = Flatten()  # register the custom flatten layer
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.conv2 = nn.Conv2d(6, 16, 5)

        # HERE, the input size is 16*5*5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))
        x = F.max_pool2d(F.relu(self.conv2(x)), 2)
        # x = x.view(x.size()[0], -1)
        x = self.flatten(x)  # use the flatten layer instead of view
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
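
As a quick sanity check (a minimal sketch, assuming a CIFAR10-sized input of 3x32x32):

    import torch

    net = Net()
    x = torch.randn(1, 3, 32, 32)  # dummy batch with one CIFAR10-sized image
    print(net(x).shape)            # expected: torch.Size([1, 10])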
answered by Salih Karagoz

Although this is old, I'll answer for future readers. 5x5 is the image dimension after all the convolutions and poolings. The docs give a formula to compute it:

    H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)

and the same formula holds for the width. So you start with a 32x32 image. After the first convolution layer you get a 28x28 image because of the 5x5 kernel (you lose 2 pixels on each side). After the pooling you get 14x14, because it is a 2x2 mask with a stride of 2. The second convolution layer then gives you a 10x10 image, and the last pooling gives 5x5. Finally, you multiply by the number of output channels, 16, which gives 16*5*5 = 400 input features.
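
The arithmetic can be checked with a few lines of Python (a minimal sketch of the formula above; padding=0 and dilation=1 are the layer defaults used in the question):

    def out_size(h_in, kernel_size, stride=1, padding=0, dilation=1):
        # H_out = floor((H_in + 2*padding - dilation*(kernel_size - 1) - 1) / stride + 1)
        return (h_in + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1

    h = 32                        # CIFAR10 images are 32x32
    h = out_size(h, 5)            # conv1 (5x5 kernel) -> 28
    h = out_size(h, 2, stride=2)  # max pool (2x2, stride 2) -> 14
    h = out_size(h, 5)            # conv2 (5x5 kernel) -> 10
    h = out_size(h, 2, stride=2)  # max pool (2x2, stride 2) -> 5
    print(16 * h * h)             # 400 == 16*5*5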

I think @trizard's answer was close, but they misread the kernel size.

answered by lvass

Just answering to update this post: as of PyTorch 1.3 there is a built-in nn.Flatten() layer:

https://pytorch.org/docs/stable/_modules/torch/nn/modules/flatten.html
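
For example, a minimal sketch of an equivalent network built with the built-in layer (layer sizes taken from the question):

    import torch.nn as nn

    model = nn.Sequential(
        nn.Conv2d(3, 6, 5),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Conv2d(6, 16, 5),
        nn.ReLU(),
        nn.MaxPool2d(2),
        nn.Flatten(),                # (N, 16, 5, 5) -> (N, 400)
        nn.Linear(16 * 5 * 5, 120),
        nn.ReLU(),
        nn.Linear(120, 84),
        nn.ReLU(),
        nn.Linear(84, 10),
    )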

Also worth mentioning: if you can't use >= 1.3 and you need that CNN output size (for example, if you have multiple heads), most people get it programmatically by running a dummy input through the conv layers, with something like:

import numpy as np
import torch

def get_flat_fts(self, input_shape, conv_net):
    # run a dummy batch of ones through the conv stack and count the output features
    f = conv_net(torch.ones(1, *input_shape))
    return int(np.prod(f.size()[1:]))
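
Usage inside __init__ would look something like this (the conv_stack attribute name here is an illustrative assumption, not part of the original code):

    # hypothetical: assumes the conv/pool layers are wrapped in self.conv_stack
    flat_dim = self.get_flat_fts((3, 32, 32), self.conv_stack)
    self.fc1 = nn.Linear(flat_dim, 120)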
answered by physincubus