
Difference between torch.flatten() and nn.Flatten()

What are the differences between torch.flatten() and torch.nn.Flatten()?

asked Feb 01 '21 by jules


People also ask

What is Torch nn flatten?

torch.flatten flattens all dimensions by default, while torch.nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. You can see this behaviour in the default values of the start_dim and end_dim arguments.

What is Torch nn module?

torch.nn.Module is the base class used to develop all neural network models. torch.nn.Sequential() is a sequential container used to combine different layers into a feed-forward network.

What does flattening a tensor do?

A flatten operation reshapes a tensor into a single dimension whose length equals the number of elements the tensor contains, i.e. a 1D array of elements. Flattening a tensor removes all of its dimensions except one.

What is the difference between reshape and view PyTorch?

torch.Tensor.view returns a new tensor that shares the same underlying data as the original, so changes to the original tensor are visible through the view (and vice versa); it requires a memory layout compatible with the requested shape. torch.reshape returns a view when the layout allows it and silently makes a copy otherwise, so changes to the original may or may not be reflected in the result.
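
A quick illustration of that distinction (a minimal sketch; reshape only copies when the memory layout rules out a view):

import torch

x = torch.arange(6)

v = x.view(2, 3)   # a view: shares storage with x
v[0, 0] = 100
print(x[0])        # tensor(100) -> the change is visible through x

# reshape returns a view when the layout allows it, a copy otherwise
y = torch.arange(6).view(2, 3).t()   # transposed, hence non-contiguous
r = y.reshape(6)                     # layout forces a copy here
r[0] = -1
print(y.flatten()[0])                # tensor(0), unchanged: r is a copy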


2 Answers

Flattening is available in three forms in PyTorch:

  • As a tensor method (OOP style), torch.Tensor.flatten, applied directly on a tensor: x.flatten().

  • As a function (functional form), torch.flatten, applied as torch.flatten(x).

  • As a module (an nn.Module layer), nn.Flatten(), generally used in a model definition.

All three are identical and share the same implementation; the only difference is that nn.Flatten has start_dim set to 1 by default, to avoid flattening the first axis (usually the batch axis), while the other two flatten from axis=0 to axis=-1, i.e. the entire tensor, if no arguments are given.
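
For instance, on a batch-shaped tensor the defaults play out like this (a quick sketch; the shape is made up for illustration):

import torch
import torch.nn as nn

x = torch.randn(32, 3, 28, 28)   # e.g. a batch of 32 RGB 28x28 images

print(x.flatten().shape)         # torch.Size([75264]) - method form, flattens everything
print(torch.flatten(x).shape)    # torch.Size([75264]) - functional form, same defaults
print(nn.Flatten()(x).shape)     # torch.Size([32, 2352]) - module form keeps the batch axis

# passing start_dim=1 explicitly makes the functional form match nn.Flatten
print(torch.flatten(x, start_dim=1).shape)   # torch.Size([32, 2352])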

answered Sep 30 '22 by Ivan


You can think of the job of torch.flatten() as simply performing a flattening operation on the tensor, with no strings attached. You give it a tensor, it flattens it and returns the result. That's all there is to it.

By contrast, nn.Flatten() is more sophisticated (i.e., it's a neural net layer). Being object oriented, it inherits from nn.Module, although internally it uses the plain tensor flatten() op in its forward() method to do the flattening. You can think of it as syntactic sugar over torch.flatten().
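
Conceptually, a minimal re-implementation might look like the following (a sketch, not PyTorch's actual source; MyFlatten is a hypothetical name):

import torch
import torch.nn as nn

class MyFlatten(nn.Module):
    # hypothetical sketch of what nn.Flatten does internally
    def __init__(self, start_dim=1, end_dim=-1):
        super().__init__()
        self.start_dim = start_dim
        self.end_dim = end_dim

    def forward(self, input):
        # delegate to the plain tensor op, as described above
        return input.flatten(self.start_dim, self.end_dim)

x = torch.randn(4, 2, 3)
assert MyFlatten()(x).shape == nn.Flatten()(x).shape   # both (4, 6)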


Important difference: a notable distinction is that torch.flatten(), with its default arguments, always returns a 1D tensor, provided the input is at least 1D, whereas nn.Flatten(), with its defaults, always returns a 2D tensor, provided the input is at least 2D (given a 1D tensor as input, it throws an IndexError).
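
A short demonstration of this behaviour with a 1D input (a minimal sketch):

import torch
import torch.nn as nn

t = torch.arange(5)              # 1D input

print(torch.flatten(t).shape)    # torch.Size([5]) -> still 1D, no complaints

try:
    nn.Flatten()(t)
except IndexError as e:
    print(e)                     # start_dim=1 is out of range for a 1D tensor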


Comparisons:

  • torch.flatten() is an API whereas nn.Flatten() is a neural net layer.

  • torch.flatten() is a Python function whereas nn.Flatten() is a Python class.

  • Because of the above point, nn.Flatten() comes with a lot of methods and attributes inherited from nn.Module.

  • torch.flatten() can be used in the wild (e.g., for simple tensor ops) whereas nn.Flatten() is expected to be used in an nn.Sequential() block as one of the layers (see the sketch after this list).

  • torch.flatten() is tracked by autograd only when its input tensor has the requires_grad flag set to True; nn.Flatten() relies on the same op internally, but as an nn.Module it slots naturally into graph-aware model definitions.

  • torch.flatten() operates on plain tensors only, whereas nn.Flatten() is typically placed between neural net layers (e.g., after conv layers and before a linear layer) inside a model.

  • Both torch.flatten() and nn.Flatten() return a view of the input tensor when possible (a copy is made if the memory layout does not permit a view). When a view is returned, any modification to the result also affects the input tensor. (See the code demo below.)
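
For instance, a typical placement of nn.Flatten() between the conv and linear parts of a small net might look like this (a sketch; the layer sizes are illustrative only):

import torch
import torch.nn as nn

# hypothetical toy ConvNet; sizes chosen only for illustration
model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),   # (N, 1, 28, 28) -> (N, 8, 28, 28)
    nn.ReLU(),
    nn.MaxPool2d(2),                              # -> (N, 8, 14, 14)
    nn.Flatten(),                                 # -> (N, 8 * 14 * 14) = (N, 1568)
    nn.Linear(8 * 14 * 14, 10),                   # classifier head
)

x = torch.randn(32, 1, 28, 28)
print(model(x).shape)   # torch.Size([32, 10])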


Code demo:

# input tensors to work with
In [109]: t1 = torch.arange(12).reshape(3, -1)
In [110]: t2 = torch.arange(12, 24).reshape(3, -1)
In [111]: t3 = torch.arange(12, 36).reshape(3, 2, -1)   # 3D tensor

Flattening with torch.flatten():

In [113]: t1flat = torch.flatten(t1)

In [114]: t1flat
Out[114]: tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11])

# modification to the flattened tensor    
In [115]: t1flat[-1] = -1

# input tensor is also modified; thus flattening is a view.
In [116]: t1
Out[116]: 
tensor([[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, -1]])

Flattening with nn.Flatten():

In [123]: nnfl = nn.Flatten()
In [124]: t3flat = nnfl(t3)

# note that the result is 2D, as opposed to 1D with torch.flatten
In [125]: t3flat
Out[125]: 
tensor([[12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27],
        [28, 29, 30, 31, 32, 33, 34, 35]])

# modification to the result
In [126]: t3flat[-1, -1] = -1

# input tensor also modified. Thus, flattened result is a view.
In [127]: t3
Out[127]: 
tensor([[[12, 13, 14, 15],
         [16, 17, 18, 19]],

        [[20, 21, 22, 23],
         [24, 25, 26, 27]],

        [[28, 29, 30, 31],
         [32, 33, 34, -1]]])

Tidbit: torch.flatten() is the precursor to nn.Flatten() and its sibling nn.Unflatten(), having existed from the very beginning. A legitimate use-case for nn.Flatten() emerged later, since flattening is a common requirement in almost all ConvNets (just before the softmax or elsewhere), so it was added in PR #22245.

There have also been recent proposals to use nn.Flatten() in ResNets for model surgery.

answered Sep 30 '22 by kmario23