What are the differences between torch.flatten() and torch.nn.Flatten()?
torch.flatten flattens all dimensions by default, while torch.nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. You can see this behaviour in the default values of the start_dim and end_dim arguments: start_dim=0 for torch.flatten versus start_dim=1 for nn.Flatten, with end_dim=-1 for both.
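A quick way to see the different defaults is to flatten the same tensor both ways and compare the shapes. This is a minimal sketch; the (2, 3, 4) input shape is an arbitrary choice for illustration.

```python
import torch
from torch import nn

x = torch.randn(2, 3, 4)  # e.g. a batch of 2 samples, each of shape (3, 4)

# torch.flatten defaults: start_dim=0, end_dim=-1, so every axis is merged
print(torch.flatten(x).shape)  # torch.Size([24])

# nn.Flatten defaults: start_dim=1, end_dim=-1, so the batch axis (dim 0) is kept
print(nn.Flatten()(x).shape)  # torch.Size([2, 12])

# Passing start_dim=1 to the function reproduces the module's default behaviour
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([2, 12])
```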
torch.nn.Module is the base class for all neural network modules; models and layers are built by subclassing it. torch.nn.Sequential() is a sequential container used to chain layers together into a feed-forward network.
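As an illustration of how nn.Flatten() typically sits inside nn.Sequential(), here is a small sketch. The architecture (one Conv2d, a ReLU, a Linear layer) and the 28x28 single-channel input size are assumptions made only for this example.

```python
import torch
from torch import nn

# Illustrative ConvNet: layer sizes are assumptions for 28x28, 1-channel inputs
model = nn.Sequential(
    nn.Conv2d(in_channels=1, out_channels=4, kernel_size=3),  # -> (N, 4, 26, 26)
    nn.ReLU(),
    nn.Flatten(),                                             # -> (N, 4*26*26) = (N, 2704)
    nn.Linear(4 * 26 * 26, 10),                               # -> (N, 10)
)

batch = torch.randn(8, 1, 28, 28)
out = model(batch)
print(out.shape)  # torch.Size([8, 10])

# nn.Sequential and nn.Flatten are both nn.Module subclasses
print(isinstance(model, nn.Module), isinstance(model[2], nn.Module))  # True True
```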
A flatten operation on a tensor reshapes it into a single dimension whose size equals the number of elements in the tensor. This is the same thing as a 1D array of elements: flattening a tensor means removing all of its dimensions except one.
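For example, flattening gives the same values as reshape(-1): one dimension whose length is numel(), with the elements in row-major order. A minimal check, assuming an arbitrary (2, 3, 4) input:

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

flat = torch.flatten(x)
# The single remaining dimension has as many entries as the tensor has elements
print(flat.shape, x.numel())  # torch.Size([24]) 24
# Elements come out in row-major (C) order, exactly as with reshape(-1)
print(torch.equal(flat, x.reshape(-1)))  # True
print(flat[:5])  # tensor([0, 1, 2, 3, 4])
```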
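Tensor.view returns a view that shares its underlying storage with the original tensor, so any change to one is reflected in the other; if the requested shape is incompatible with the tensor's strides, view raises an error instead. torch.reshape returns a view when the input's strides allow it and a copy otherwise, so you should not rely on it either sharing or not sharing memory with the original tensor.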
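The sketch below contrasts the two, using a transposed (and therefore non-contiguous) tensor to force reshape into its copying path. The tensors are arbitrary examples.

```python
import torch

x = torch.arange(6).reshape(2, 3)

# view shares storage with x, so a write through v shows up in x
v = x.view(-1)
v[0] = 100
print(x[0, 0].item())  # 100

# On a contiguous tensor, reshape also returns a view (same memory)
print(x.reshape(-1).data_ptr() == x.data_ptr())  # True

# A transpose is non-contiguous, so it cannot be viewed as 1D ...
t = x.t()
view_failed = False
try:
    t.view(-1)
except RuntimeError:
    view_failed = True  # t's strides are incompatible with a 1D view
print(view_failed)  # True

# ... but reshape silently falls back to copying the data
r = t.reshape(-1)
r[0] = -1
print(t[0, 0].item())  # still 100: r is a copy, not a view
```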
Flattening is available in three forms in PyTorch:
As a tensor method (OOP style), torch.Tensor.flatten, applied directly on a tensor: x.flatten().
As a function (functional form), torch.flatten, applied as: torch.flatten(x).
As a module (an nn.Module layer), nn.Flatten(). Generally used in a model definition.
All three are identical and share the same implementation, the only difference being that nn.Flatten has start_dim set to 1 by default to avoid flattening the first axis (usually the batch axis), while the other two flatten from axis=0 to axis=-1, i.e. the entire tensor, if no arguments are given.
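The equivalence of the three forms can be checked directly: once start_dim is aligned, they produce identical results. The input shape below is an arbitrary, image-like batch chosen for illustration.

```python
import torch
from torch import nn

x = torch.randn(4, 3, 2, 2)

a = x.flatten()                 # tensor method, start_dim=0 by default
b = torch.flatten(x)            # function, start_dim=0 by default
c = nn.Flatten(start_dim=0)(x)  # module, start_dim overridden to 0
print(torch.equal(a, b) and torch.equal(b, c))  # True

# With the module's default start_dim=1, the other two need start_dim=1 to match
d = nn.Flatten()(x)
print(torch.equal(d, x.flatten(1)) and torch.equal(d, torch.flatten(x, 1)))  # True
print(d.shape)  # torch.Size([4, 12])
```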
You can think of torch.flatten() as simply performing a flattening operation on a tensor, with no strings attached. You give it a tensor, it flattens it and returns the result. That's all there is to it.
By contrast, nn.Flatten() is a neural net layer. Being object-oriented, it inherits from nn.Module, although internally its forward() method just calls the plain Tensor.flatten() op to flatten the tensor. You can think of it as syntactic sugar over torch.flatten().
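To make the "syntactic sugar" point concrete, here is a hypothetical MyFlatten module. It is not PyTorch's source code, only a minimal sketch of the same idea: store start_dim and end_dim, and delegate to Tensor.flatten in forward().

```python
import torch
from torch import nn


class MyFlatten(nn.Module):
    """Hypothetical minimal re-implementation, for illustration only."""

    def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        # All the work is done by the plain tensor op
        return tensor.flatten(self.start_dim, self.end_dim)


x = torch.randn(5, 2, 3, 4)
print(torch.equal(MyFlatten()(x), nn.Flatten()(x)))  # True
# Like nn.Flatten, it has no learnable parameters
print(list(MyFlatten().parameters()))  # []
```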
Important difference: with default arguments, torch.flatten() always returns a 1D tensor, provided that the input is at least 1D, whereas nn.Flatten() always returns a 2D tensor, provided that the input is at least 2D (with a 1D tensor as input, it raises an IndexError).
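The 1D/2D behaviour and the IndexError can be reproduced with a short sketch; the input tensors are arbitrary examples.

```python
import torch
from torch import nn

v = torch.arange(5)  # a 1D input
print(torch.flatten(v).shape)  # torch.Size([5]): the result is still 1D

raised = False
try:
    nn.Flatten()(v)  # default start_dim=1 does not exist on a 1D tensor
except IndexError as err:
    raised = True
    print("IndexError:", err)

m = torch.zeros(3, 4)  # a 2D input: start_dim=1 and end_dim=-1 both refer to dim 1
print(nn.Flatten()(m).shape)  # torch.Size([3, 4]): shape unchanged, still 2D
```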
torch.flatten() is a plain tensor operation, whereas nn.Flatten() is a neural net layer.
torch.flatten() is a Python function, whereas nn.Flatten() is a Python class.
Because of the above point, nn.Flatten() comes with the many methods and attributes it inherits from nn.Module, plus its own start_dim and end_dim attributes.
torch.flatten() can be used anywhere (e.g., for simple tensor ops or inside a custom forward() method), whereas nn.Flatten() is typically used as one of the layers in an nn.Sequential() block.
Autograd treats both in exactly the same way: the flattened output is tracked (it gets a grad_fn) whenever the input tensor has its requires_grad flag set to True and gradient mode is enabled. nn.Flatten() has no learnable parameters, so wrapping the operation in a module does not change gradient tracking.
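Neither accepts layers (e.g., Linear/Conv1d) as inputs; both operate on tensors. The practical difference is that nn.Flatten() is itself a layer, so it can be placed between other layers (e.g. after a convolutional block and before a Linear layer) to flatten the preceding layer's output.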
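Both torch.flatten() and nn.Flatten() return a view of the input tensor when possible (e.g. for a contiguous input), so any modification to the result also affects the input tensor. If the input cannot be viewed in the flattened shape (e.g. a transposed tensor), its data is copied instead. (See the code below.)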
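The autograd and view-versus-copy points above can be checked with a short sketch using arbitrary example tensors. nn.Flatten(0) is used here only so that the module flattens the whole tensor, matching torch.flatten's default.

```python
import torch
from torch import nn

x = torch.randn(2, 3, requires_grad=True)

# Both forms record a grad_fn when the input requires grad ...
print(torch.flatten(x).grad_fn is not None, nn.Flatten(0)(x).grad_fn is not None)  # True True

# ... and neither does when it does not
y = torch.randn(2, 3)
print(torch.flatten(y).grad_fn, nn.Flatten(0)(y).grad_fn)  # None None

# Gradients flow back through nn.Flatten exactly as through torch.flatten
nn.Flatten(0)(x).sum().backward()
print(torch.equal(x.grad, torch.ones(2, 3)))  # True

# View vs copy: a contiguous input is viewed, a non-contiguous one is copied
base = torch.arange(6).reshape(2, 3)
print(torch.flatten(base).data_ptr() == base.data_ptr())      # True (view)
print(torch.flatten(base.t()).data_ptr() == base.data_ptr())  # False (copy)
```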
Code demo:

```python
>>> import torch
>>> from torch import nn
>>> # input tensors to work with
>>> t1 = torch.arange(12).reshape(3, -1)
>>> t2 = torch.arange(12, 24).reshape(3, -1)
>>> t3 = torch.arange(12, 36).reshape(3, 2, -1)  # 3D tensor
```

Flattening with torch.flatten():

```python
>>> t1flat = torch.flatten(t1)
>>> t1flat
tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])
>>> # modification to the flattened tensor
>>> t1flat[-1] = -1
>>> # the input tensor is also modified: t1 is contiguous, so flattening returned a view
>>> t1
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])
```

Flattening with nn.Flatten():

```python
>>> nnfl = nn.Flatten()
>>> t3flat = nnfl(t3)
>>> # note that the result is 2D, as opposed to 1D with torch.flatten
>>> t3flat
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])
>>> # modification to the result
>>> t3flat[-1, -1] = -1
>>> # the input tensor is also modified, so again the flattened result is a view
>>> t3
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])
```
Tidbit: torch.flatten() is the precursor to nn.Flatten() and its sibling nn.Unflatten(), since it has existed from the very beginning. Later, there was a legitimate use case for nn.Flatten(), since flattening is a common requirement in almost all ConvNets (just before the softmax or elsewhere), so it was added in PR #22245. There are also recent proposals to use nn.Flatten() in ResNets for model surgery.