
PyTorch Variable indexing loses one dimension

I want to extract every row (each vector along the last dimension) from a Variable, but indexing drops one dimension.

This is my code:

import torch
from torch.autograd import Variable
t = torch.rand((2,2,4))
x = Variable(t)
print(x)
shape = x.size()
for i in range(shape[0]):
    for j in range(shape[1]):
        print(x[i,j])

The output is:

Variable containing:
(0 ,.,.) =
  0.6717  0.8216  0.5100  0.9106
  0.3280  0.8182  0.5781  0.3919

(1 ,.,.) =
  0.8823  0.4237  0.6620  0.0817
  0.5781  0.4187  0.3769  0.0498
[torch.FloatTensor of size 2x2x4]

Variable containing:
 0.6717
 0.8216
 0.5100
 0.9106
[torch.FloatTensor of size 4]

Variable containing:
 0.3280
 0.8182
 0.5781
 0.3919
[torch.FloatTensor of size 4]

Variable containing:
 0.8823
 0.4237
 0.6620
 0.0817
[torch.FloatTensor of size 4]

Variable containing:
 0.5781
 0.4187
 0.3769
 0.0498
[torch.FloatTensor of size 4]

How can I get a [torch.FloatTensor of size 1x4] instead?

qimeng wang, asked Feb 02 '18 06:02




2 Answers

In your case, x is a 2x2x4 tensor. Indexing with an integer removes that dimension: x[0] gives the 2x4 tensor at index 0 of the first dimension, and x[i,j] gives the vector of size 4 at position (i,j). To retain one of the dimensions, you can either index with a slice, x[i,j:j+1], or reshape the result, x[i,j].view(1,4). Your code would then look like:

import torch
from torch.autograd import Variable
t = torch.rand((2,2,4))
x = Variable(t)
print(x)
shape = x.size()
for i in range(shape[0]):
    for j in range(shape[1]):
        print(x[i,j:j+1])

or

import torch
from torch.autograd import Variable
t = torch.rand((2,2,4))
x = Variable(t)
print(x)
shape = x.size()
for i in range(shape[0]):
    for j in range(shape[1]):
        print(x[i,j].view(1,4))

Either version will give you the desired 1x4 result.

Edit:

As mentioned in the answer by nnnmmm, torch.unsqueeze(x[i, j], 0) also works, since it adds a dimension of size 1 at position 0.
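
To make the comparison concrete, here is a minimal sketch that checks the shape each approach produces. It assumes PyTorch 0.4 or later, where Variable has been merged into Tensor, so a plain tensor is used instead of the Variable wrapper from the question. The variable names (row_indexed, row_sliced, and so on) are illustrative only.

```python
import torch

x = torch.rand(2, 2, 4)
i, j = 0, 1

# Integer indexing on both leading dimensions drops them: shape (4,)
row_indexed = x[i, j]

# A slice on the second dimension keeps it with size 1: shape (1, 4)
row_sliced = x[i, j:j+1]

# Reshape the size-4 vector into one row: shape (1, 4)
row_viewed = x[i, j].view(1, 4)

# Insert a new dimension of size 1 at position 0: shape (1, 4)
row_unsqueezed = torch.unsqueeze(x[i, j], 0)

for name, t in [("indexed", row_indexed), ("sliced", row_sliced),
                ("viewed", row_viewed), ("unsqueezed", row_unsqueezed)]:
    print(name, tuple(t.shape))
```

All three 1x4 results hold the same values and share storage with x, so none of them copies data.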

patapouf_ai, answered Nov 11 '22 00:11


Try torch.unsqueeze(x[i, j], 0).
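
As a rough illustration of what the second argument does, the sketch below shows how the position passed to torch.unsqueeze decides where the new size-1 dimension goes. It assumes PyTorch 0.4 or later and uses a plain tensor; the names row, as_row, and as_col are made up for the example.

```python
import torch

x = torch.rand(2, 2, 4)
row = x[0, 1]                       # shape (4,)

as_row = torch.unsqueeze(row, 0)    # new dimension in front: shape (1, 4)
as_col = torch.unsqueeze(row, 1)    # new dimension at the end: shape (4, 1)

# Equivalent spellings of the dim=0 case
as_row_method = row.unsqueeze(0)    # method form: shape (1, 4)
as_row_none = row[None]             # indexing with None: shape (1, 4)

print(tuple(as_row.shape), tuple(as_col.shape))
```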

nnnmmm, answered Nov 10 '22 23:11