How to flatten input inside nn.Sequential
Model = nn.Sequential(x.view(x.shape[0],-1),
nn.Linear(784,256),
nn.ReLU(),
nn.Linear(256,128),
nn.ReLU(),
nn.Linear(128,64),
nn.ReLU(),
nn.Linear(64,10),
nn.LogSoftmax(dim=1))
torch.flatten flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only the dimensions from start_dim through end_dim are flattened.
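As a small illustration of how start_dim and end_dim change the result, here is a sketch using made-up tensor shapes (the shapes are assumptions chosen only for the example):

```python
import torch

# Hypothetical batch: 3 samples, each a 2x2 block of values.
t = torch.arange(12).reshape(3, 2, 2)

# Default arguments (start_dim=0, end_dim=-1) collapse every dimension.
flat_all = torch.flatten(t)
print(flat_all.shape)         # torch.Size([12])

# start_dim=1 keeps the batch dimension and flattens the rest.
flat_per_sample = torch.flatten(t, start_dim=1)
print(flat_per_sample.shape)  # torch.Size([3, 4])

# start_dim and end_dim together flatten only a middle range of dimensions.
u = torch.zeros(2, 3, 4, 5)
flat_middle = torch.flatten(u, start_dim=1, end_dim=2)
print(flat_middle.shape)      # torch.Size([2, 12, 5])
```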
nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. You can see this behaviour in the default values of its arguments, start_dim=1 and end_dim=-1.
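Applied to the network from the question, that could look like the sketch below. The input shape (a batch of 1x28x28 images, which gives the 784 features the first Linear layer expects) is an assumption, since the question does not state it.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Flatten(),            # (N, 1, 28, 28) -> (N, 784), batch dimension kept
    nn.Linear(784, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.ReLU(),
    nn.Linear(64, 10),
    nn.LogSoftmax(dim=1),
)

# Assumed input: a batch of 32 single-channel 28x28 images.
x = torch.randn(32, 1, 28, 28)
out = model(x)
print(out.shape)  # torch.Size([32, 10])
```

Because nn.Flatten is a module, it can sit in the Sequential like any other layer, whereas x.view(...) is a tensor operation on a specific x and cannot.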
A tensor can be flattened into a one-dimensional tensor by reshaping it using the method torch.flatten(). This method supports both real and complex-valued input tensors. It takes a torch tensor as its input and returns a torch tensor flattened into one dimension.
The forward() method of Sequential accepts any input and forwards it to the first module it contains. It then "chains" outputs to inputs sequentially for each subsequent module, finally returning the output of the last module.
The fastest way to flatten is not to create a new module and add it to the main model via main.add_module('flatten', Flatten()). Instead, a simple out = inp.reshape(inp.size(0), -1) inside the forward of your model is faster, as I showed here.
# Using Sequential to create a small model. When `model` is run,
# input will first be passed to `Conv2d(1,20,5)`. The output of
# `Conv2d(1,20,5)` will be used as the input to the first
# `ReLU`; the output of the first `ReLU` will become the input
# for `Conv2d(20,64,5)`.
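The chaining described in that comment can be checked by hand. The sketch below builds the small Conv2d/ReLU model the comment refers to and passes a tensor through the layers one at a time; the input shape is a hypothetical choice for illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 20, 5),
    nn.ReLU(),
    nn.Conv2d(20, 64, 5),
    nn.ReLU(),
)

# Hypothetical input: 4 single-channel 28x28 images.
x = torch.randn(4, 1, 28, 28)

# Manual chaining: feed each layer's output into the next layer.
manual = x
for layer in model:
    manual = layer(manual)

# Calling the container gives the same result as the manual loop.
auto = model(x)
print(torch.allclose(manual, auto))
print(auto.shape)  # each 5x5 conv without padding shrinks 28 -> 24 -> 20
```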
You can create a new module/class as below and use it in the Sequential just like the other modules (call Flatten()).
import torch

class Flatten(torch.nn.Module):
    def forward(self, x):
        # Keep the batch dimension, collapse everything else.
        batch_size = x.shape[0]
        return x.view(batch_size, -1)
Ref: https://discuss.pytorch.org/t/flatten-layer-of-pytorch-build-by-sequential-container/5983
EDIT: Flatten is part of torch now. See https://pytorch.org/docs/stable/nn.html?highlight=flatten#torch.nn.Flatten
The flatten function, defined as
torch.flatten(input, start_dim=0, end_dim=-1) → Tensor
is comparable in speed to view(), but reshape is even faster in my measurements.
import torch
import torch.nn as nn

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)

flatten = Flatten()

# 3x2x2 tensor filled with random integers in [0, 10)
t = torch.Tensor(3,2,2).random_(0, 10)
print(t, t.shape)

#https://pytorch.org/docs/master/torch.html#torch.flatten
f = torch.flatten(t, start_dim=1, end_dim=-1)
print(f, f.shape)

#https://pytorch.org/docs/master/torch.html#torch.view
f = t.view(t.size(0), -1)
print(f, f.shape)

#https://pytorch.org/docs/master/torch.html#torch.reshape
f = t.reshape(t.size(0), -1)
print(f, f.shape)
Speed check
# flatten 3.49 µs ± 146 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# view 3.23 µs ± 228 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
# reshape 3.04 µs ± 93 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
If we use the class from above:
flatten = Flatten()
t = torch.Tensor(3,2,2).random_(0, 10)
%timeit f=flatten(t)
5.16 µs ± 122 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
This result shows that creating a class is the slower approach, which is why it is faster to flatten tensors inside forward. I think this is the main reason they haven't promoted nn.Flatten.
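%timeit only works inside IPython or Jupyter. To repeat the comparison in a plain script, something like the following sketch with the standard timeit module should work. The tensor shape, loop counts, and the extra nn.Flatten entry are my own choices for the example, and the numbers depend on hardware and PyTorch version, so the ordering is worth re-measuring rather than taking as given.

```python
import timeit

import torch
import torch.nn as nn

t = torch.rand(3, 2, 2)
flatten_module = nn.Flatten()

# Each candidate flattens everything except the batch dimension.
candidates = {
    "flatten": lambda: torch.flatten(t, start_dim=1),
    "view": lambda: t.view(t.size(0), -1),
    "reshape": lambda: t.reshape(t.size(0), -1),
    "nn.Flatten": lambda: flatten_module(t),
}

results = {}
for name, fn in candidates.items():
    # Best of 5 repeats of 10,000 calls, converted to microseconds per call.
    best = min(timeit.repeat(fn, number=10_000, repeat=5))
    results[name] = best / 10_000 * 1e6
    print(f"{name:10s} {results[name]:.2f} µs per call")
```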
So my suggestion would be to use reshape inside forward for speed. Something like this:
out = inp.reshape(inp.size(0), -1)
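In context, that line might sit in a model like the sketch below. The class name Net, the layer sizes, and the input shape are hypothetical and only serve to make the example runnable.

```python
import torch
import torch.nn as nn

class Net(nn.Module):  # hypothetical model name
    def __init__(self):
        super().__init__()
        # The Sequential holds only the layers; flattening happens in forward.
        self.layers = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 10),
            nn.LogSoftmax(dim=1),
        )

    def forward(self, inp):
        # Flatten everything except the batch dimension.
        out = inp.reshape(inp.size(0), -1)
        return self.layers(out)

net = Net()
x = torch.randn(8, 1, 28, 28)  # assumed batch of 1x28x28 inputs
y = net(x)
print(y.shape)  # torch.Size([8, 10])
```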