How to flatten input in `nn.Sequential` in PyTorch

How do I flatten the input inside nn.Sequential? I tried:

Model = nn.Sequential(x.view(x.shape[0], -1),
                      nn.Linear(784, 256),
                      nn.ReLU(),
                      nn.Linear(256, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))
asked Dec 28 '18 by Khagendra

People also ask

How do you flatten inputs in PyTorch?

flatten. Flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened.
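
For example (assuming import torch):

t = torch.zeros(2, 3, 4)
print(torch.flatten(t).shape)               # torch.Size([24])  - all dims flattened
print(torch.flatten(t, start_dim=1).shape)  # torch.Size([2, 12]) - batch dim kept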

What does nn flatten () do?

nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. You can see this behaviour in the default values of the start_dim and end_dim arguments.
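
For example, on a batch of MNIST-sized images:

import torch
import torch.nn as nn

x = torch.zeros(32, 1, 28, 28)
print(nn.Flatten()(x).shape)  # torch.Size([32, 784]) - the batch dimension is kept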

What is flatten in PyTorch?

A tensor can be flattened into a one-dimensional tensor by reshaping it using the method torch.flatten(). This method supports both real and complex-valued input tensors. It takes a torch tensor as its input and returns a torch tensor flattened into one dimension.

How does the forward() method of Sequential work?

The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then "chains" outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.

How do I flatten a layer in PyTorch?

The fastest way to flatten a layer is not to create a new module and add it to the main model via main.add_module('flatten', Flatten()). Instead, a simple out = inp.reshape(inp.size(0), -1) inside the forward of your model is faster, as the second answer below shows.

How do you use Sequential to create a small model?

When using Sequential to create a small model, the input is first passed to Conv2d(1,20,5). The output of Conv2d(1,20,5) is used as the input to the first ReLU; the output of the first ReLU becomes the input for Conv2d(20,64,5), and so on.
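
That description accompanies the example in the PyTorch nn.Sequential docs, which reads roughly:

model = nn.Sequential(
          nn.Conv2d(1, 20, 5),
          nn.ReLU(),
          nn.Conv2d(20, 64, 5),
          nn.ReLU()
        )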


2 Answers

You can create a new module/class as below and use it in the Sequential just as you use other modules (by instantiating Flatten()).

import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        batch_size = x.shape[0]
        return x.view(batch_size, -1)  # keep the batch dimension, flatten the rest
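
For example, used in the model from the question:

model = nn.Sequential(Flatten(),
                      nn.Linear(784, 256),
                      nn.ReLU(),
                      nn.Linear(256, 128),
                      nn.ReLU(),
                      nn.Linear(128, 64),
                      nn.ReLU(),
                      nn.Linear(64, 10),
                      nn.LogSoftmax(dim=1))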

Ref: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983

EDIT: nn.Flatten is now part of PyTorch. See https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten
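
Since nn.Flatten defaults to start_dim=1, it keeps the batch dimension and is a drop-in replacement for the custom class above, e.g. nn.Sequential(nn.Flatten(), nn.Linear(784, 256), ...).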

answered Sep 20 '22 by Umang Gupta

The flatten method is defined as

torch.flatten(input, start_dim=0, end_dim=-1) → Tensor

Its speed is comparable to view(), and reshape() is even faster, as the timings below show.

import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

flatten = Flatten()

t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)


#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)


#https://pytorch.org/docs/master/tensors.html#torch.Tensor.view
f = t.view(t.size(0), -1)
print(f, f.shape)


#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)
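
The timings under "Speed check" below can presumably be reproduced in an IPython/Jupyter session with %timeit on the three calls above (an assumption, matching the %timeit call shown further down):

%timeit torch.flatten(t, start_dim=1, end_dim=-1)
%timeit t.view(t.size(0), -1)
%timeit t.reshape(t.size(0), -1)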

Speed check

# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

If we use the Flatten class from above:

flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)


5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)

This result shows that creating a class is the slower approach, which is why it is faster to flatten tensors inside forward(). I think this is the main reason nn.Flatten hasn't been promoted.

So my suggestion is to flatten inside forward(), for speed. Something like this:

out = inp.reshape(inp.size(0), -1)
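
A minimal sketch of a full model written this way (the class name Net and the layer sizes, borrowed from the question, are illustrative assumptions, not the answerer's exact code):

import torch
import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, inp):
        out = inp.reshape(inp.size(0), -1)  # flatten inside forward, as suggested
        out = torch.relu(self.fc1(out))
        return torch.log_softmax(self.fc2(out), dim=1)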
answered Sep 19 '22 by prosti