
How can I add an element to a PyTorch tensor along a certain dimension?

I have a tensor inps, which has a size of [64, 161, 1] and I have some new data d which has a size of [64, 161]. How can I add d to inps such that the new size is [64, 161, 2]?

Shamoon asked Apr 08 '20 13:04


2 Answers

A cleaner way is to combine .unsqueeze() and torch.cat(), which uses the PyTorch interface directly:

import torch

# create two sample tensors with the shapes from the question
inps = torch.randn([64, 161, 1])
d = torch.randn([64, 161])

# bring d into the same format, and then concatenate tensors
new_inps = torch.cat((inps, d.unsqueeze(2)), dim=-1)
print(new_inps.shape)  # torch.Size([64, 161, 2])

Essentially, unsqueezing d at dimension 2 (the third axis, counting from zero) turns its shape [64, 161] into [64, 161, 1], which is already the shape of inps; you just have to be careful to unsqueeze along the right dimension. The concatenation function is unfortunately named differently from its NumPy counterpart, np.concatenate, but it behaves the same way. Note that dim=-1 simply means "the last dimension"; you can also name the dimension to concatenate along explicitly, in this case by replacing it with dim=2.
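To check the two claims above, here is a small sketch that compares dim=-1 with dim=2 and compares torch.cat with NumPy's np.concatenate. It assumes NumPy is installed alongside PyTorch and uses random sample tensors with the question's shapes; the variable names are only for illustration.

```python
import numpy as np
import torch

# Sample tensors with the shapes from the question.
inps = torch.randn(64, 161, 1)
d = torch.randn(64, 161)

# dim=-1 counts from the end, so for a 3-D tensor it is the same axis as dim=2.
out_last = torch.cat((inps, d.unsqueeze(2)), dim=-1)
out_two = torch.cat((inps, d.unsqueeze(2)), dim=2)
print(torch.equal(out_last, out_two))  # True

# NumPy's counterpart is np.concatenate (with expand_dims instead of unsqueeze).
out_np = np.concatenate((inps.numpy(), np.expand_dims(d.numpy(), 2)), axis=-1)
print(np.array_equal(out_np, out_last.numpy()))  # True
```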

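Keep in mind the difference between concatenation (torch.cat, which joins tensors along an existing dimension) and stacking (torch.stack, which creates a new dimension); it is helpful for similar problems with tensor dimensions.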
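As a rough illustration of that difference, the sketch below reaches the same [64, 161, 2] result both ways, again with random sample tensors. With torch.stack the size-1 axis of inps has to be removed first, because stack requires all inputs to have identical shapes.

```python
import torch

inps = torch.randn(64, 161, 1)
d = torch.randn(64, 161)

# torch.stack inserts a NEW dimension, so both inputs must have identical shapes.
# Drop the trailing size-1 axis of inps first: (64, 161, 1) -> (64, 161).
stacked = torch.stack((inps.squeeze(2), d), dim=2)
print(stacked.shape)  # torch.Size([64, 161, 2])

# torch.cat joins along an EXISTING dimension, so d gets that axis added instead.
concatenated = torch.cat((inps, d.unsqueeze(2)), dim=2)
print(torch.equal(stacked, concatenated))  # True
```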

dennlinger answered Sep 18 '22 23:09


You first have to reshape d so that it has a third dimension along which concatenation is possible. Once it has that third dimension, both tensors have the same number of dimensions and matching sizes everywhere except dimension 2, so you can join them with torch.cat along dimension 2:

old_shape = tuple(d.shape)                          # (64, 161)
new_shape = old_shape + (1,)                        # (64, 161, 1)
# concatenate along dimension 2 -> (64, 161, 2)
inps_new = torch.cat((inps, d.view(new_shape)), 2)
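A self-contained version of this approach, with random sample tensors standing in for the question's data, shows that d.view(new_shape) produces the same tensor as d.unsqueeze(2) from the other answer here, and that the new data ends up in the second channel of the last dimension.

```python
import torch

inps = torch.randn(64, 161, 1)
d = torch.randn(64, 161)

# Appending a trailing 1 via view gives the same tensor as unsqueeze(2) here.
new_shape = tuple(d.shape) + (1,)
print(torch.equal(d.view(new_shape), d.unsqueeze(2)))  # True

inps_new = torch.cat((inps, d.view(new_shape)), 2)
print(inps_new.shape)  # torch.Size([64, 161, 2])

# The original data sits in channel 0, the new data in channel 1.
print(torch.equal(inps_new[..., 1], d))  # True
```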
Conor answered Sep 20 '22 23:09