 

PyTorch flatten doesn't maintain batch size

Tags:

python

pytorch

In Keras, using the Flatten() layer retains the batch size. For example, if the input shape to Flatten is (32, 100, 100), the output of Flatten in Keras has shape (32, 10000), but in PyTorch it has shape (320000,). Why is it so?
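The behaviour in the question can be reproduced directly (a minimal sketch; the shapes match those above):

```python
import torch

x = torch.randn(32, 100, 100)  # (batch, height, width)

# torch.flatten() defaults to start_dim=0, so the batch dimension
# is folded in together with everything else:
print(torch.flatten(x).shape)               # torch.Size([320000])

# Keeping the batch dimension requires start_dim=1:
print(torch.flatten(x, start_dim=1).shape)  # torch.Size([32, 10000])
```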

Asked by Nagabhushan S N, Feb 07 '20

People also ask

How do I flatten batch Pytorch?

You can use torch.flatten() or Tensor.flatten() with start_dim=1 to start the flattening operation after the batch dimension.

How does Pytorch flatten work?

Flattens input by reshaping it into a one-dimensional tensor. If start_dim or end_dim are passed, only dimensions starting with start_dim and ending with end_dim are flattened. The order of elements in input is unchanged.
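The start_dim/end_dim behaviour described above can be sketched on a 4-D tensor, flattening only a middle range of dimensions (shapes here are arbitrary illustrative choices):

```python
import torch

x = torch.randn(2, 3, 4, 5)

# Flatten only dimensions 1 through 2; dims 0 and 3 are untouched,
# so (2, 3, 4, 5) becomes (2, 3*4, 5):
y = torch.flatten(x, start_dim=1, end_dim=2)
print(y.shape)  # torch.Size([2, 12, 5])
```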

What does NN flatten () do?

nn.Flatten flattens all dimensions starting from the second dimension (index 1) by default. You can see this behaviour in the default values of the start_dim and end_dim arguments.
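As a quick check of that default (input shape borrowed from the question):

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()       # defaults: start_dim=1, end_dim=-1
x = torch.randn(32, 100, 100)

# The batch dimension (index 0) is preserved:
print(flatten(x).shape)      # torch.Size([32, 10000])
```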

What does Batch mean in Pytorch?

Batch size is a term used in machine learning that refers to the number of training examples utilized in one iteration. For example, with a batch size of 100, 100 training examples are loaded in each iteration.


2 Answers

As OP already pointed out in their answer, the tensor operations do not default to considering a batch dimension. You can use torch.flatten() or Tensor.flatten() with start_dim=1 to start the flattening operation after the batch dimension.

Alternatively since PyTorch 1.2.0 you can define an nn.Flatten() layer in your model which defaults to start_dim=1.
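As a sketch, nn.Flatten typically sits between the convolutional and linear parts of a model; the layer sizes below are illustrative assumptions, not from the question:

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(1, 8, kernel_size=3, padding=1),  # (N, 8, 28, 28)
    nn.ReLU(),
    nn.Flatten(),               # (N, 8 * 28 * 28), batch dim kept
    nn.Linear(8 * 28 * 28, 10),
)

x = torch.randn(32, 1, 28, 28)
print(model(x).shape)  # torch.Size([32, 10])
```

Because nn.Flatten is a module, it composes with the rest of the network and needs no manual reshape in the forward pass.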

Answered by jodag, Oct 21 '22


Yes. As mentioned in this thread, PyTorch tensor operations such as flatten, view, and reshape do not consider the batch dimension by default.

In general when using modules like Conv2d, you don't need to worry about batch size. PyTorch takes care of it. But when dealing directly with tensors, you need to take care of batch size.

In Keras, Flatten() is a layer. But in PyTorch, flatten() is an operation on the tensor. Hence, the batch size needs to be handled manually.

Answered by Nagabhushan S N, Oct 21 '22