I came across a line of code used to reduce a 3D Tensor to a 2D Tensor in PyTorch. The 3D tensor x is of size torch.Size([500, 50, 1]) and this line of code:
x = x[lengths - 1, range(len(lengths))]
was used to reduce x to a 2D tensor of size torch.Size([50, 1]). lengths is also a tensor, of shape torch.Size([50]), containing integer values that are used to index x.
Please can anyone explain how this works? Thank you.
After being quite stumped by the behavior, I did some more digging into this and found that it is consistent with how multi-dimensional NumPy arrays are indexed. What makes this counter-intuitive is the less obvious fact that both index arrays have to have the same length, i.e. in this case len(lengths).
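A minimal NumPy sketch of the same pattern (the array and names below are made up for illustration, not taken from the question):

import numpy as np
arr = np.arange(24).reshape(4, 3, 2)   # small 3D array, shape (4, 3, 2)
rows = np.array([3, 0, 2])             # indexes dimension 0
cols = np.array([0, 1, 2])             # indexes dimension 1, must have the same length as rows
picked = arr[rows, cols]               # picks arr[3, 0], arr[0, 1], arr[2, 2]
print(picked.shape)                    # (3, 2): one element per index pair, last dimension kept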
In fact, it works as the following:
* lengths determines the order in which you access the first dimension. I.e., if you have a 1D array a = [0, 1, 2, ...., 500] and access it with the list b = [300, 200, 100], then the result is a[b] = [300, 200, 100] (the accessed values equal the indices here only because a[i] = i). The lengths - 1 simply converts each (1-based) length into the 0-based index of its last element, so every position is read at its final entry along the first dimension.
* range(len(lengths)) then simply chooses the i-th element in the i-th row. If you have a square matrix, you can interpret this as the diagonal of the matrix. Since you only access a single element for each position along the first two dimensions, this can be stored in a single dimension (thus reducing your 3D tensor to 2D). The last dimension is simply kept as is. See the sketch after this list.
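A small sketch of both points, with toy shapes instead of [500, 50, 1] (x and lengths below are illustrative values, not the ones from the question):

import torch
x = torch.randn(6, 4, 1)                     # stands in for the [500, 50, 1] tensor
lengths = torch.tensor([3, 1, 6, 2])         # one length per position along dimension 1
out = x[lengths - 1, range(len(lengths))]    # the indexing from the question
print(out.shape)                             # torch.Size([4, 1])
# element i of out is x[lengths[i] - 1, i], i.e. the "diagonal" selection described above
manual = torch.stack([x[lengths[i] - 1, i] for i in range(len(lengths))])
print(torch.equal(out, manual))              # True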
If you want to play around with this, I strongly recommend changing the range() value to something longer or shorter, which will result in the following error:
IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (x,) (y,)
where x and y are your specific length values.
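For example, reusing the toy tensors from the sketch above (the exact wording of the message varies a bit between versions):

import torch
x = torch.randn(6, 4, 1)
lengths = torch.tensor([3, 1, 6, 2])         # len(lengths) == 4
x[lengths - 1, range(3)]                     # range(3) has length 3, lengths has length 4
# raises IndexError: shape mismatch, because the two index sequences cannot be broadcast together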
To write this accessing method out in long form and see what happens "under the hood", also consider the example below:
import torch
x = torch.randn(500, 50, 1)                  # a random tensor of shape [500, 50, 1]
lengths = torch.tensor([2, 30, 1, 4])        # random examples to explore
diag = list(range(len(lengths)))             # [0, 1, 2, 3]
result = []
for i, row in enumerate(lengths):
    temp_tensor = x[row, :, :]               # temp_tensor.shape = [50, 1]
    temp_tensor = temp_tensor[diag[i]]       # temp_tensor.shape = [1]
    result.append(temp_tensor)
# back to pytorch
result = torch.stack(result)
result.shape                                 # torch.Size([4, 1])
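As a quick check (same x and lengths as in the loop above), this matches the one-line indexing from the question, minus the - 1 offset:

result_vectorized = x[lengths, range(len(lengths))]
print(result_vectorized.shape)               # torch.Size([4, 1])
print(torch.equal(result, result_vectorized))  # True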
The key feature here is passing the values of the tensor lengths as indices for x.
Here is a simplified example. I swapped the dimensions of container, so the index dimension goes first:
import torch
container = torch.arange(0, 50)
container = container.reshape((5, 10))
>>>tensor([[ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9],
[10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
[20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
[30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
[40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])
indices = torch.arange(2, 7, dtype=torch.long)
>>>tensor([2, 3, 4, 5, 6])
print(container[range(len(indices)), indices])
>>>tensor([ 2, 13, 24, 35, 46])
Note: we take one element from each row (range(len(indices)) produces sequential row numbers), with the column number given by indices[row_number].
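Written as an explicit loop over rows (a small check of the same selection, using the container and indices defined above; row_number is just an illustrative name):

picked = torch.stack([container[row_number, indices[row_number]] for row_number in range(len(indices))])
print(picked)   # tensor([ 2, 13, 24, 35, 46]), same as container[range(len(indices)), indices]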