 

Concat tensors in PyTorch

I have a tensor called data of the shape [128, 4, 150, 150] where 128 is the batch size, 4 is the number of channels, and the last 2 dimensions are height and width. I have another tensor called fake of the shape [128, 1, 150, 150].

I want to drop the last channel (the last slice along the 2nd dimension) of data, so that data has the shape [128, 3, 150, 150], and then concatenate it with fake so that the result has the shape [128, 4, 150, 150].

In other words, I want to concatenate the first 3 channels of data with fake to get a tensor with 4 channels.

I am using PyTorch and came across the functions torch.cat() and torch.stack().
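Roughly, the difference between the two: torch.cat() joins tensors along a dimension they already have (all other dimensions must match), while torch.stack() creates a new dimension (all inputs must have exactly the same shape). A minimal sketch with small, made-up shapes:

```python
import torch

# Small stand-in tensors: 2 samples, 3 channels / 1 channel, 4x4 images
a = torch.rand(2, 3, 4, 4)
b = torch.rand(2, 1, 4, 4)

# torch.cat joins along an EXISTING dimension (here dim=1, the channels)
catted = torch.cat([a, b], dim=1)
print(catted.shape)  # torch.Size([2, 4, 4, 4])

# torch.stack inserts a NEW dimension; the inputs must have identical shapes
x = torch.rand(4, 4)
y = torch.rand(4, 4)
stacked = torch.stack([x, y], dim=0)
print(stacked.shape)  # torch.Size([2, 4, 4])
```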

Here is some sample code I've written:

fake_combined = []
for j in range(batch_size):
    fake_combined.append(torch.stack((data[j][0].to(device), data[j][1].to(device), data[j][2].to(device), fake[j][0].to(device))))
fake_combined = torch.tensor(fake_combined, dtype=torch.float32)
fake_combined = fake_combined.to(device)

But I am getting an error in the line:

fake_combined = torch.tensor(fake_combined, dtype=torch.float32)

The error is:

ValueError: only one element tensors can be converted to Python scalars
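The error can be reproduced with tiny tensors. torch.tensor() builds a tensor from (nested) Python data, and a tensor inside the list is only accepted if it can be turned into a single Python scalar, i.e. if it has exactly one element. A minimal sketch with made-up shapes:

```python
import torch

# A list of one-element tensors converts without complaint
ok = torch.tensor([torch.tensor(1.0), torch.tensor(2.0)])
print(ok)  # tensor([1., 2.])

# A list of multi-element tensors (like fake_combined) does not
parts = [torch.zeros(4, 2, 2), torch.zeros(4, 2, 2)]
raised = False
try:
    torch.tensor(parts)
except ValueError as err:
    raised = True
    print("ValueError:", err)
```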

Also, if I print the shape of fake_combined, I get [128,] instead of [128, 4, 150, 150].

And when I print the shape of fake_combined[0], I get the output as [4, 150, 150], which is as expected.

So my question is: why am I not able to convert the list to a tensor using torch.tensor()? Am I missing something? Is there a better way to do what I intend to do?
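Since fake_combined is a Python list of tensors that all have the shape [4, 150, 150], torch.stack() can join them along a new first dimension where torch.tensor() fails. A sketch of the same loop with that one change, assuming random stand-ins for data and fake, a smaller batch_size of 8, and a CPU device (all hypothetical here):

```python
import torch

# Hypothetical stand-ins for the question's tensors (smaller batch to keep it light)
batch_size = 8
device = torch.device("cpu")
data = torch.rand(batch_size, 4, 150, 150)
fake = torch.rand(batch_size, 1, 150, 150)

fake_combined = []
for j in range(batch_size):
    # Each element has shape [4, 150, 150]: 3 channels from data + 1 from fake
    fake_combined.append(torch.stack((data[j][0], data[j][1], data[j][2], fake[j][0])))

# torch.stack joins the list of equally shaped tensors along a new dim 0
fake_combined = torch.stack(fake_combined).to(device=device, dtype=torch.float32)
print(fake_combined.shape)  # torch.Size([8, 4, 150, 150])
```

Slicing the whole batch and calling torch.cat once gives the same result without the Python loop.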

Any help will be appreciated! Thanks!

asked Feb 16 '19 21:02 by ntd


2 Answers

@rollthedice32's answer works perfectly fine. For educational purposes, here's how to do it with torch.cat:

import torch

a = torch.rand(128, 4, 150, 150)
b = torch.rand(128, 1, 150, 150)

# Drop the last channel (keep channels 0-2)
a = a[:, :3, :, :]
# Concatenate along the 2nd dimension (dim=1, the channel dimension)
result = torch.cat([a, b], dim=1)
print(result.shape)
# => torch.Size([128, 4, 150, 150])
answered Nov 05 '22 13:11 by Coolness


You could also just assign fake to that particular channel (index 2 below), which modifies orig in place.

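import torch

orig = torch.randint(low=0, high=10, size=(2, 3, 2, 2))
fake = torch.randint(low=111, high=119, size=(2, 1, 2, 2))
# The list index [2] keeps the channel dimension, so the target shape (2, 1, 2, 2) matches fake
orig[:, [2], :, :] = fake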

Original Before

tensor([[[[0, 1],
          [8, 0]],

         [[4, 9],
          [6, 1]],

         [[8, 2],
          [7, 6]]],


        [[[1, 1],
          [8, 5]],

         [[5, 0],
          [8, 6]],

         [[5, 5],
          [2, 8]]]])

Fake

tensor([[[[117, 115],
          [114, 111]]],


        [[[115, 115],
          [118, 115]]]])

Original After

tensor([[[[  0,   1],
          [  8,   0]],

         [[  4,   9],
          [  6,   1]],

         [[117, 115],
          [114, 111]]],


        [[[  1,   1],
          [  8,   5]],

         [[  5,   0],
          [  8,   6]],

         [[115, 115],
          [118, 115]]]])
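The assignment above overwrites orig in place. Applied to the shapes from the question, a sketch (assuming random stand-ins for data and fake) that clones data first, in case the original must be kept, and writes fake into the last channel, which is index 3 when there are 4 channels:

```python
import torch

# Random stand-ins with the question's shapes
data = torch.rand(128, 4, 150, 150)
fake = torch.rand(128, 1, 150, 150)

# Copy first so that `data` itself is left untouched
result = data.clone()
# Overwrite the last channel; the slice 3:4 keeps the channel dimension,
# so the target shape [128, 1, 150, 150] matches `fake`
result[:, 3:4] = fake
print(result.shape)  # torch.Size([128, 4, 150, 150])
```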

Hope this helps! :)

answered Nov 05 '22 11:11 by kabrapankaj32