 

How do I flatten a tensor in pytorch?

Tags:

pytorch

Given a tensor of multiple dimensions, how do I flatten it so that it has a single dimension?

For example:

>>> t = torch.rand([2, 3, 5])
>>> t.shape
torch.Size([2, 3, 5])

How do I flatten it to have shape:

torch.Size([30]) 
asked Apr 06 '19 07:04 by Tom Hale




1 Answer

TL;DR: torch.flatten()

Use torch.flatten(), which was introduced in v0.4.1 and documented in v1.0rc1:

>>> t = torch.tensor([[[1, 2],
...                    [3, 4]],
...                   [[5, 6],
...                    [7, 8]]])
>>> torch.flatten(t)
tensor([1, 2, 3, 4, 5, 6, 7, 8])
>>> torch.flatten(t, start_dim=1)
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

For v0.4.1 and earlier, use t.reshape(-1).
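A minimal sketch applying this to the tensor from the question, assuming a PyTorch version recent enough to provide torch.flatten with its start_dim/end_dim parameters. It checks that the result has the requested shape and matches t.reshape(-1) element for element:

```python
import torch

# The 3-D tensor from the question
t = torch.rand([2, 3, 5])

# Collapse every dimension into one
flat = torch.flatten(t)
print(flat.shape)  # torch.Size([30])

# reshape(-1) yields the same elements in the same (row-major) order
print(torch.equal(flat, t.reshape(-1)))  # True

# start_dim/end_dim flatten only a range of dimensions
print(torch.flatten(t, start_dim=1).shape)             # torch.Size([2, 15])
print(torch.flatten(t, start_dim=0, end_dim=1).shape)  # torch.Size([6, 5])
```

Partial flattening with start_dim=1 is the usual way to keep a leading batch dimension intact.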


With t.reshape(-1):

If the requested view is contiguous in memory, this will be equivalent to t.view(-1) and memory will not be copied.

Otherwise it will be equivalent to t.contiguous().view(-1).
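A rough way to observe the view-versus-copy behaviour described above is to compare data_ptr() before and after reshaping. This is an illustration only: the PyTorch docs advise against relying on whether reshape returns a view or a copy, and the transposed tensor here is just one convenient way (an assumption for this sketch) to get a non-contiguous input:

```python
import torch

# Contiguous tensor: reshape(-1) can return a view over the same memory
a = torch.arange(6).reshape(2, 3)
a_flat = a.reshape(-1)
print(a_flat.data_ptr() == a.data_ptr())  # True: no copy was made

# Transposing yields a non-contiguous tensor over the same storage
b = a.t()
print(b.is_contiguous())  # False

# reshape(-1) must copy to produce b's row-major element order
b_flat = b.reshape(-1)
print(b_flat.data_ptr() == b.data_ptr())  # False: new memory
print(b_flat)  # tensor([0, 3, 1, 4, 2, 5])

# Same result as the explicit contiguous().view(-1)
print(torch.equal(b_flat, b.contiguous().view(-1)))  # True
```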


Other non-options:

  • t.view(-1) won't copy memory, but raises a RuntimeError when the tensor's size and stride are incompatible with a flat view (e.g. a non-contiguous tensor)

  • t.resize(-1) gives a RuntimeError (see below)

  • t.resize(t.numel()) is discouraged as a low-level method (see discussion below)
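A small sketch of the first bullet: view(-1) works on the contiguous tensor from the question but fails once the dimensions are transposed, while reshape(-1) handles both cases. The transpose(0, 2) call is only an assumed way to produce a non-contiguous tensor for this demo:

```python
import torch

t = torch.rand([2, 3, 5])
print(t.view(-1).shape)  # torch.Size([30]): fine, t is contiguous

nc = t.transpose(0, 2)   # shape [5, 3, 2], non-contiguous strides

view_failed = False
try:
    nc.view(-1)
except RuntimeError as err:
    # view() cannot merge dimensions whose strides don't line up
    view_failed = True
    print("view(-1) failed:", err)

# reshape(-1) falls back to copying, so it still succeeds
print(nc.reshape(-1).shape)  # torch.Size([30])
```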

(Note: like NumPy's reshape(), PyTorch's reshape() returns a view when possible and a copy otherwise, so writing to the result may or may not modify the original tensor.)


t.resize(t.numel()) needs some discussion. The torch.Tensor.resize_ documentation says:

The storage is reinterpreted as C-contiguous, ignoring the current strides (unless the target size equals the current size, in which case the tensor is left unchanged)

Because the current strides are ignored when resizing to t.numel() elements, the elements may appear in a different order than with reshape(-1). However, "size" here may mean the memory size rather than the tensor's size.
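A small sketch of that reordering, using the in-place resize_ quoted above on a transposed tensor (the tensors are chosen purely for illustration). reshape(-1) follows the tensor's logical element order, while resize_ reads the underlying storage as if it were C-contiguous:

```python
import torch

base = torch.arange(6).reshape(2, 3)  # storage holds 0..5 in order
u = base.t()                          # shape [3, 2], strides (1, 3)

# reshape(-1) follows u's logical (row-major) element order
by_reshape = u.reshape(-1)
print(by_reshape)  # tensor([0, 3, 1, 4, 2, 5])

# resize_ ignores u's strides and reinterprets the storage
# as C-contiguous, so it reads the raw storage order instead.
# Only u's metadata changes; base is left as it was.
u.resize_(u.numel())
print(u)  # tensor([0, 1, 2, 3, 4, 5])
```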

It would be nice if t.resize(-1) worked for both convenience and efficiency, but with torch 1.0.1.post2, t = torch.rand([2, 3, 5]); t.resize(-1) gives:

RuntimeError: requested resize to -1 (-1 elements in total), but the given  tensor has a size of 2x2 (4 elements). autograd's resize can only change the  shape of a given tensor, while preserving the number of elements. 

I raised a feature request for this here, but the consensus was that resize() was a low level method, and reshape() should be used in preference.

answered Oct 06 '22 10:10 by Tom Hale