No N-dimensional transpose in PyTorch

PyTorch's torch.transpose function only transposes 2D inputs.

On the other hand, TensorFlow's tf.transpose function can transpose a tensor with an arbitrary number N of dimensions.

Can someone please explain why PyTorch does not (or cannot) have N-dimensional transpose functionality? Is this due to the dynamic nature of computation graph construction in PyTorch, versus TensorFlow's define-then-run paradigm?

asked Jun 30 '17 at 08:06 by jhuang







1 Answer

It's simply called something different in PyTorch: torch.Tensor.permute lets you reorder an arbitrary number of dimensions, just as tf.transpose does in TensorFlow.
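A small sketch (not part of the original answer) of how permute covers tf.transpose's behaviour. The helper name `tf_style_transpose` is hypothetical; it only assumes TensorFlow's documented default of reversing the dimension order when no `perm` is given. It also shows that `torch.transpose` swaps exactly two dimensions of an N-D tensor, so chaining swaps can reach the same result as a single `permute`.

```python
import torch

def tf_style_transpose(x: torch.Tensor, perm=None) -> torch.Tensor:
    """Hypothetical helper mimicking tf.transpose's `perm` argument.

    With perm=None the dimension order is reversed (TensorFlow's default);
    otherwise perm lists the source dimension for each output position.
    """
    if perm is None:
        perm = tuple(reversed(range(x.dim())))
    return x.permute(*perm)

x = torch.randn(2, 3, 4, 5)

# Arbitrary reordering of all four dimensions in one call.
print(tf_style_transpose(x, (3, 1, 0, 2)).shape)  # torch.Size([5, 3, 2, 4])

# Default: reverse every dimension.
print(tf_style_transpose(x).shape)  # torch.Size([5, 4, 3, 2])

# torch.transpose swaps exactly two dimensions; two swaps give the full reversal here.
y = x.transpose(0, 3).transpose(1, 2)
print(torch.equal(y, tf_style_transpose(x)))  # True
```

In practice you would call `x.permute(...)` directly; the wrapper only makes the correspondence with `tf.transpose` explicit.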

As an example of how you'd convert a 4D image tensor from NHWC to NCHW (not tested, so might contain bugs):

>>> import torch
>>> img_nhwc = torch.randn(10, 480, 640, 3)
>>> img_nhwc.size()
torch.Size([10, 480, 640, 3])
>>> img_nchw = img_nhwc.permute(0, 3, 1, 2)
>>> img_nchw.size()
torch.Size([10, 3, 480, 640])
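A follow-up sketch (an addition, not from the original answer) covering two things that usually come up right after this conversion: going back from NCHW to NHWC with the inverse permutation, and the fact that permute returns a non-contiguous view sharing storage with its input, so `.view()` needs a `.contiguous()` call first (`.reshape()` is the alternative that copies only when needed). The tensor sizes are reused from the example above.

```python
import torch

img_nhwc = torch.randn(10, 480, 640, 3)

# NHWC -> NCHW: move the channel axis (dim 3) into position 1.
img_nchw = img_nhwc.permute(0, 3, 1, 2)

# NCHW -> NHWC: the inverse permutation moves channels back to the end.
back = img_nchw.permute(0, 2, 3, 1)
print(torch.equal(back, img_nhwc))  # True

# permute only changes strides: the result shares storage with the input
# and is generally not contiguous in memory.
print(img_nchw.data_ptr() == img_nhwc.data_ptr())  # True
print(img_nchw.is_contiguous())  # False

# .view() requires compatible strides, so make a contiguous copy first.
flat = img_nchw.contiguous().view(10, 3, -1)
print(flat.shape)  # torch.Size([10, 3, 307200])
```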
answered Oct 14 '22 at 20:10 by etarion
