 

How to expand the dimensions of a tensor in PyTorch

I'm a newcomer to PyTorch. Suppose I have a tensor like this:

A = torch.tensor([[1, 2, 3], [4, 5, 6]])

My question is: how do I get a 3-dimensional tensor (two stacked copies of A) like this:

B = torch.tensor([[[1, 2, 3],
                   [4, 5, 6]],

                  [[1, 2, 3],
                   [4, 5, 6]]])
Pengfei.C asked Dec 19 '25 08:12


2 Answers

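You can concatenate two copies along the first dimension with torch.cat. Here A already has a leading dimension of size 1, i.e. shape (1, 2, 3) (for example, the question's tensor after A.unsqueeze(0), shown as floats); calling torch.cat on the question's 2-D A directly would give a (4, 3) tensor instead.

>>> A
tensor([[[1., 2., 3.],
         [4., 5., 6.]]])
>>> B = torch.cat((A, A))
>>> B
tensor([[[1., 2., 3.],
         [4., 5., 6.]],

        [[1., 2., 3.],
         [4., 5., 6.]]])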
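A minimal sketch, assuming you start from the question's 2-D integer tensor rather than the already-unsqueezed one above: add the leading dimension with unsqueeze(0) first, then concatenate. torch.stack inserts the new dimension itself, so it does both steps in one call. The names A3, B_cat and B_stack are illustrative only.

```python
import torch

# The question's 2-D tensor: shape (2, 3), dtype int64
A = torch.tensor([[1, 2, 3], [4, 5, 6]])

# Add a leading dimension of size 1: shape (1, 2, 3)
A3 = A.unsqueeze(0)

# Concatenate two copies along dim 0: shape (2, 2, 3)
B_cat = torch.cat((A3, A3), dim=0)

# torch.stack creates the new dim 0 itself, so no unsqueeze is needed
B_stack = torch.stack((A, A), dim=0)

print(B_cat.shape)                  # torch.Size([2, 2, 3])
print(torch.equal(B_cat, B_stack))  # True
```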
Bhupen answered Dec 20 '25 21:12


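Just use the repeat method. Passing three sizes to a 2-D tensor treats the extra size as a new leading dimension, so this gives shape (2, 2, 3):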

B = A.repeat(2, 1, 1)
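A short sketch comparing repeat with expand, which the question's title hints at. The variable n and the names B_repeat and B_expand are illustrative. repeat copies the data, while expand on the unsqueezed tensor returns a view that shares A's memory: cheaper, but clone it before any in-place write.

```python
import torch

A = torch.tensor([[1, 2, 3], [4, 5, 6]])  # shape (2, 3)
n = 2  # number of copies along the new leading dimension

# repeat(): the extra leading size adds a new dimension; data is copied
B_repeat = A.repeat(n, 1, 1)  # shape (n, 2, 3)

# expand(): -1 keeps an existing size; returns a view, no copy is made
B_expand = A.unsqueeze(0).expand(n, -1, -1)  # shape (n, 2, 3)

print(B_repeat.shape)                     # torch.Size([2, 2, 3])
print(torch.equal(B_repeat, B_expand))    # True
```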
Girish Hegde answered Dec 20 '25 21:12