
torch.cat but create a new dimension

I would like to concatenate tensors, not along an existing dimension, but along a newly created one.

For example:

import torch

x = torch.randn(2, 3)
x.shape # (2, 3)

torch.cat([x,x,x,x], 0).shape # (8, 3)
# This concats along dim 0, not what I want

torch.cat([x,x,x,x], -1).shape # (2, 12)
# This concats along dim 1, not what I want

torch.cat([x[None, :, :],x[None, :, :],x[None, :, :],x[None, :, :]], 0).shape 
# => (4, 2, 3)
# This is what I want, but unwieldy
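For reference, the workaround above can be written more compactly with Tensor.unsqueeze and a list comprehension, though it is still a two-step "add a dimension, then cat" pattern. This is a minimal sketch; the list name tensors is only for illustration.

```python
import torch

x = torch.randn(2, 3)
tensors = [x, x, x, x]  # illustrative input list

# Give each tensor a leading size-1 dimension, then concatenate along it.
# For a 2-D tensor, t.unsqueeze(0) is equivalent to t[None, :, :].
y = torch.cat([t.unsqueeze(0) for t in tensors], dim=0)
print(y.shape)  # torch.Size([4, 2, 3])
```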

Is there a simpler way?

asked Aug 30 '19 at 13:08 by Benjamin Crouzier



1 Answer

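Use torch.stack, which concatenates a sequence of tensors of the same shape along a new dimension (inserted at dim=0 by default):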

torch.stack([x,x,x,x]).shape # (4, 2, 3)
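torch.stack also accepts a dim argument that controls where the new dimension is inserted. The sketch below (input list name is illustrative) shows a few positions, checks that stacking along dim d gives the same result as cat of tensors unsqueezed at d, and shows that inputs of different shapes are rejected.

```python
import torch

x = torch.randn(2, 3)
tensors = [x, x, x, x]  # illustrative input list

# The new dimension of size len(tensors) is inserted at position `dim`.
s0 = torch.stack(tensors)          # shape (4, 2, 3)
s1 = torch.stack(tensors, dim=1)   # shape (2, 4, 3)
s2 = torch.stack(tensors, dim=-1)  # shape (2, 3, 4)

# Stacking along dim d is equivalent to cat of the inputs unsqueezed at d.
for d in (0, 1, 2):
    assert torch.equal(
        torch.stack(tensors, dim=d),
        torch.cat([t.unsqueeze(d) for t in tensors], dim=d),
    )

# All inputs must have the same shape; otherwise stack raises an error.
mismatch_error = None
try:
    torch.stack([x, torch.randn(3, 2)])
except RuntimeError as err:
    mismatch_error = err

print(s0.shape, s1.shape, s2.shape)
```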
answered Sep 22 '22 at 09:09 by Benjamin Crouzier
