
Understanding the order when reshaping a tensor

For a tensor:

x = torch.tensor([
    [
        [[0.4495, 0.2356],
          [0.4069, 0.2361],
          [0.4224, 0.2362]],
                   
         [[0.4357, 0.6762],
          [0.4370, 0.6779],
          [0.4406, 0.6663]]
    ],    
    [
        [[0.5796, 0.4047],
          [0.5655, 0.4080],
          [0.5431, 0.4035]],
         
         [[0.5338, 0.6255],
          [0.5335, 0.6266],
          [0.5204, 0.6396]]
    ]
])

First, I would like to split it into 2 (x.shape[0]) tensors and then concatenate them. I don't actually have to split it as long as I get the correct output, but splitting it and then concatenating the pieces back together makes a lot more sense to me visually.

For example:

# the splits always have the same shape
split1 = torch.tensor([
    [[0.4495, 0.2356],
    [0.4069, 0.2361],
    [0.4224, 0.2362]],

    [[0.4357, 0.6762],
    [0.4370, 0.6779],
    [0.4406, 0.6663]]
])
split2 = torch.tensor([
    [[0.5796, 0.4047],
    [0.5655, 0.4080],
    [0.5431, 0.4035]],

    [[0.5338, 0.6255],
    [0.5335, 0.6266],
    [0.5204, 0.6396]]
])

split1 = torch.cat((split1[0], split1[1]), dim=1)
split2 = torch.cat((split2[0], split2[1]), dim=1)
what_i_want = torch.cat((split1, split2), dim=0).reshape(x.shape[0], split1.shape[0], split1.shape[1])


To get the above result, I thought that reshaping directly with x.reshape([2, 3, 4]) would work. It produced the correct dimensions but the wrong element order.

In general, I am:

  1. not sure how to split the tensor into x.shape[0] tensors.
  2. confused about how reshape works. Most of the time I get the dimensions right, but the order of the numbers is always incorrect.

Thank you

tom c asked Apr 29 '26 15:04


1 Answer

In Python, PyTorch, NumPy, C++, etc., the elements of a multi-dimensional array are stored in memory in row-major order by default:

[ first, second
  third, fourth ]

In MATLAB, Fortran, etc., the order is column-major:

[ first,  third
  second, fourth ]

For higher-dimensional tensors, row-major order means that the last dimension varies fastest in memory and the first dimension varies slowest.

You can easily visualize it using torch.arange followed by .view:

a = torch.arange(24)
a.view(2,3,4)

This results in:

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

As you can see, consecutive elements run along the last dimension first (within a row), then along the second dimension (from one row to the next), and finally along the first dimension.
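The row-major rule can be checked numerically. The following is a minimal sketch, assuming a contiguous tensor built with torch.arange as above; the helper flat_offset is a hypothetical name introduced only for illustration. It uses the tensor's strides to compute where element [i, j, k] sits in memory.

```python
import torch

a = torch.arange(24).view(2, 3, 4)

# Strides: how many elements to skip in memory to advance by one
# along each dimension. The last dimension has stride 1 (it varies fastest).
strides = a.stride()

def flat_offset(i, j, k):
    # Hypothetical helper: position of element [i, j, k] in the flat storage.
    return i * strides[0] + j * strides[1] + k * strides[2]

# Because `a` was filled by arange, each value equals its own memory offset.
offset = flat_offset(1, 2, 3)
value = a[1, 2, 3].item()
print(strides, offset, value)
```

For a contiguous tensor of shape (2, 3, 4) the strides are (12, 4, 1), so the offset of [1, 2, 3] is 12 + 8 + 3 = 23, which matches the value stored there.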

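When you reshape a tensor, you do not change the underlying order of the elements, only the shape of the tensor. However, if you permute a tensor, you swap its dimensions, so reading the result in row-major order visits the elements in a different sequence. (In PyTorch, permute returns a view with different strides rather than moving the data in memory.)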
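To make the distinction concrete, here is a small sketch, assuming the same arange tensor as above. It inspects strides and storage to show that permute only changes how the existing data is indexed, and that calling .contiguous() is what actually lays the elements out in the new order.

```python
import torch

a = torch.arange(24).view(2, 3, 4)

viewed = a.view(3, 2, 4)        # same element order, new shape
permuted = a.permute(1, 0, 2)   # dimensions 0 and 1 swapped

same_shape = viewed.shape == permuted.shape
same_values = torch.equal(viewed, permuted)

# permute does not copy: it shares storage with `a` and only changes strides.
shares_storage = permuted.data_ptr() == a.data_ptr()
print(permuted.stride(), permuted.is_contiguous())

# .contiguous() copies the data into row-major order for the new shape.
materialized = permuted.contiguous()
print(materialized.stride(), materialized.reshape(-1)[:8].tolist())
```

Both tensors have shape (3, 2, 4), but their contents differ. Flattening the permuted tensor starts with 0, 1, 2, 3, 12, 13, 14, 15 instead of 0 through 7.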

Look at the difference between a.view(3,2,4) and a.view(2,3,4).permute(1,0,2): the two resulting tensors have the same shape, but their elements are ordered differently:

In []: a.view(3,2,4)
Out[]:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7]],

        [[ 8,  9, 10, 11],
         [12, 13, 14, 15]],

        [[16, 17, 18, 19],
         [20, 21, 22, 23]]])

In []: a.view(2,3,4).permute(1,0,2)
Out[]:
tensor([[[ 0,  1,  2,  3],
         [12, 13, 14, 15]],

        [[ 4,  5,  6,  7],
         [16, 17, 18, 19]],

        [[ 8,  9, 10, 11],
         [20, 21, 22, 23]]])
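
Applied to the tensor from the question, the same idea suggests permuting before reshaping. The sketch below is one possible way to do it, assuming the desired output is exactly what the question's split-and-concatenate code produces. torch.unbind is used here for the explicit split along the first dimension, and the names splits, rows, expected, result and naive are introduced only for this example.

```python
import torch

# The tensor from the question, shape (2, 2, 3, 2).
x = torch.tensor([
    [[[0.4495, 0.2356], [0.4069, 0.2361], [0.4224, 0.2362]],
     [[0.4357, 0.6762], [0.4370, 0.6779], [0.4406, 0.6663]]],
    [[[0.5796, 0.4047], [0.5655, 0.4080], [0.5431, 0.4035]],
     [[0.5338, 0.6255], [0.5335, 0.6266], [0.5204, 0.6396]]],
])

# Explicit version: split along dim 0, then join each pair side by side.
splits = torch.unbind(x, dim=0)                          # 2 tensors of shape (2, 3, 2)
rows = [torch.cat((s[0], s[1]), dim=1) for s in splits]  # each of shape (3, 4)
expected = torch.stack(rows, dim=0)                      # shape (2, 3, 4)

# Direct version: move the dimension to be merged next to the last one,
# then let reshape fuse the two trailing dimensions.
result = x.permute(0, 2, 1, 3).reshape(2, 3, 4)

# Plain reshape keeps the memory order, so it gives the right shape
# but not the layout the question asks for.
naive = x.reshape(2, 3, 4)

print(torch.equal(result, expected))
print(torch.equal(naive, expected))
```

The permute puts the size-2 dimension that should be merged directly in front of the last dimension, so that the subsequent reshape joins each pair of length-2 rows into one length-4 row.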

Shai answered May 02 '26 04:05