 

PyTorch - Stack dimension must be exactly the same?

In PyTorch, given tensors a of shape (1X11) and b of shape (1X11), torch.stack((a,b),0) gives me a tensor of shape (2X1X11).

However, when a has shape (2X11) and b has shape (1X11), torch.stack((a,b),0) raises an error saying "the two tensor size must exactly be the same".
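A minimal reproduction, assuming random tensors stand in for the model outputs (the values and the extra tensor `c` are hypothetical). It also checks a property of `torch.stack` that explains the error: stacking along `dim` matches concatenating the inputs after each is unsqueezed at `dim`, so every dimension other than the new one has to agree.

```python
import torch

# Hypothetical stand-ins for the two model outputs (random values, shapes from the question).
a = torch.randn(2, 11)
b = torch.randn(1, 11)

# torch.stack inserts a new dimension, so all inputs must have exactly the same shape.
try:
    torch.stack((a, b), 0)
    stack_failed = False
except RuntimeError as err:
    stack_failed = True
    print("stack failed:", err)

# stack(tensors, dim) is equivalent to cat of the inputs each unsqueezed at dim.
# For a and b that would mean joining (1, 2, 11) with (1, 1, 11): they differ in
# dim 1, which is not the joined dimension, hence the error above.
c = torch.randn(1, 11)
stacked = torch.stack((b, c), 0)
via_cat = torch.cat((b.unsqueeze(0), c.unsqueeze(0)), 0)
print(stacked.shape)                  # torch.Size([2, 1, 11])
print(torch.equal(stacked, via_cat))  # True
```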

Because the two tensors are outputs of a model (with gradients attached), I can't convert them to NumPy to use np.stack() or np.vstack().

Is there a solution that uses as little GPU memory as possible?

asked May 17 '18 14:05 by Achaca




1 Answer

It seems you want to use torch.cat() (concatenate tensors along an existing dimension) and not torch.stack() (concatenate/stack tensors along a new dimension):

import torch

a = torch.randn(1, 42, 1, 1)
b = torch.randn(1, 42, 1, 1)

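# stack() inserts a new dimension 0, so a and b must have identical shapes
ab = torch.stack((a, b), 0)
print(ab.shape)
# torch.Size([2, 1, 42, 1, 1])

# cat() joins along the existing dimension 0
ab = torch.cat((a, b), 0)
print(ab.shape)
# torch.Size([2, 42, 1, 1])
# cat() only requires the other dimensions to match: a has size 1 and ab has size 2 along dim 0
aab = torch.cat((a, ab), 0)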
print(aab.shape)
# torch.Size([3, 42, 1, 1])
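Applied to the shapes from the question, a sketch using hypothetical leaf tensors `a` and `b` (created with `requires_grad=True` in place of real model outputs). It shows that `torch.cat` accepts a (2X11) and a (1X11) tensor and that the result stays in the autograd graph, so no NumPy conversion is needed. `torch.cat` allocates a single output tensor of the combined size on the inputs' device; nothing is copied to the host.

```python
import torch

# Hypothetical leaf tensors standing in for the two model outputs from the question.
a = torch.randn(2, 11, requires_grad=True)
b = torch.randn(1, 11, requires_grad=True)

# cat joins along the existing dimension 0; only the remaining dimension (11) must match.
ab = torch.cat((a, b), 0)
print(ab.shape)  # torch.Size([3, 11])

# The result is still part of the autograd graph, so gradients flow back to a and b.
ab.sum().backward()
print(a.grad.shape, b.grad.shape)  # torch.Size([2, 11]) torch.Size([1, 11])
```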
answered Sep 21 '22 15:09 by benjaminplanche