 

Flatten layer of a PyTorch model built with a Sequential container

I am trying to build a CNN with PyTorch's Sequential container. My problem is that I cannot figure out how to flatten the layer.

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', make_it_flatten)

What should I put in place of "make_it_flatten"? I tried to flatten main itself, but it does not work because main has no method called view:

main = main.view(-1, 16*3*3)
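A minimal sketch of why main.view(...) fails: view is a method of tensors, while main is an nn.Module container, so the flattening has to act on the tensor that the container produces. The stand-in for self._conv_block (a Conv2d followed by a ReLU) and the 32x32 input size are assumptions; with that input the pooled output is 16x6x6 rather than the 16x3x3 used in the question.

```python
import torch
import torch.nn as nn

# Assumption: the question's self._conv_block adds a Conv2d followed by a ReLU.
# The real helper is not shown, so this stand-in is hypothetical.
main = nn.Sequential(
    nn.Conv2d(3, 6, 5), nn.ReLU(),
    nn.MaxPool2d(2, 2),
    nn.Conv2d(6, 16, 3), nn.ReLU(),
    nn.MaxPool2d(2, 2),
)

# A container has no .view(); view() belongs to tensors, not to modules.
print(hasattr(main, "view"))  # False

# Flattening must be applied to the output tensor of the container.
x = torch.randn(4, 3, 32, 32)               # assumed batch of 4 RGB 32x32 images
features = main(x)                          # shape (4, 16, 6, 6)
flat = features.view(features.size(0), -1)
print(flat.shape)                           # torch.Size([4, 576])
```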
asked Aug 09 '17 08:08 by StereoMatching


2 Answers

This might not be exactly what you are looking for, but you can simply create your own nn.Module that flattens any input, which you can then add to the nn.Sequential() object:

import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, x):
        # Keep the batch dimension; -1 merges all remaining dimensions into one.
        return x.view(x.size()[0], -1)

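x.size()[0] selects the batch dimension, and -1 tells view to infer the size of the second dimension from the total number of elements, so each sample becomes a single feature vector. This flattens any tensor/Variable, whatever its number of dimensions.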
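A quick sanity check of the Flatten module on a made-up tensor shaped like the question's pooled output (16 channels of 3x3, matching the 16*3*3 in the question); the batch size of 2 is an arbitrary assumption. The last two comparisons use nn.Flatten and torch.flatten, which are only available in newer PyTorch releases.

```python
import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, x):
        # Keep dim 0 (the batch); -1 lets view infer the product of the other dims.
        return x.view(x.size()[0], -1)

x = torch.randn(2, 16, 3, 3)   # assumed: batch of 2, 16 channels, 3x3 feature maps
out = Flatten()(x)
print(out.shape)               # torch.Size([2, 144])

# Row i holds exactly the elements of sample i, in row-major order.
print(torch.equal(out[1], x[1].reshape(-1)))  # True

# Built-in equivalents in newer releases: nn.Flatten() (start_dim=1 by default)
# and torch.flatten(x, 1).
print(torch.equal(out, nn.Flatten()(x)))      # True
print(torch.equal(out, torch.flatten(x, 1)))  # True
```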

And using it in nn.Sequential:

main = nn.Sequential()
self._conv_block(main, 'conv_0', 3, 6, 5)
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2,2))
self._conv_block(main, 'conv_1', 6, 16, 3)
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2,2)) 
main.add_module('flatten', Flatten())
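An end-to-end sketch of that container with a classifier layer after the flatten. It assumes self._conv_block adds a Conv2d followed by a ReLU (a hypothetical stand-in, since the helper is not shown), 32x32 RGB inputs, and 10 output classes. Because Flatten keeps the batch dimension, the nn.Linear that follows only needs the per-sample feature count, which is 16 * 6 * 6 = 576 for this input size.

```python
import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, x):
        # (N, C, H, W) -> (N, C*H*W)
        return x.view(x.size()[0], -1)

main = nn.Sequential()
# Hypothetical stand-in for self._conv_block: Conv2d followed by ReLU.
main.add_module('conv_0', nn.Conv2d(3, 6, 5))
main.add_module('relu_0', nn.ReLU())
main.add_module('max_pool_0_2_2', nn.MaxPool2d(2, 2))
main.add_module('conv_1', nn.Conv2d(6, 16, 3))
main.add_module('relu_1', nn.ReLU())
main.add_module('max_pool_1_2_2', nn.MaxPool2d(2, 2))
main.add_module('flatten', Flatten())
# With the assumed 32x32 input, the feature maps here are 16 x 6 x 6.
main.add_module('fc', nn.Linear(16 * 6 * 6, 10))

logits = main(torch.randn(8, 3, 32, 32))  # assumed batch of 8 images
print(logits.shape)                       # torch.Size([8, 10])
```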
answered Oct 26 '22 13:10 by cleros


The fastest way to flatten the layer is not to create a new module like the one below and add it to main via main.add_module('flatten', Flatten()):

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

Instead, a simple out = inp.reshape(inp.size(0), -1) inside the forward method of your model is faster, as I showed here.
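A minimal sketch of that alternative using a hypothetical Net class: the convolutional part stays in an nn.Sequential and the flatten happens in forward(). It assumes a Conv2d + ReLU stand-in for the question's _conv_block, 32x32 inputs, and 10 classes. The speed claim is the answer's own and is not measured here; one concrete difference is that reshape, unlike view, also works on non-contiguous tensors (it copies when it has to).

```python
import torch
import torch.nn as nn

class Net(nn.Module):
    # Hypothetical model: conv layers in a Sequential, flatten done in forward().
    def __init__(self, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 6, 5), nn.ReLU(), nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 3), nn.ReLU(), nn.MaxPool2d(2, 2),
        )
        # 16 * 6 * 6 assumes 32x32 inputs.
        self.classifier = nn.Linear(16 * 6 * 6, num_classes)

    def forward(self, inp):
        out = self.features(inp)
        # Flatten everything except the batch dimension.
        out = out.reshape(out.size(0), -1)
        return self.classifier(out)

net = Net()
scores = net(torch.randn(8, 3, 32, 32))  # assumed batch of 8 images
print(scores.shape)                      # torch.Size([8, 10])
```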

answered Oct 26 '22 13:10 by prosti