 

Subsetting A Pytorch Tensor Using Square-Brackets

Tags:

python

pytorch

I came across a line of code used to reduce a 3D tensor to a 2D tensor in PyTorch. The 3D tensor x has size torch.Size([500, 50, 1]), and this line of code:

x = x[lengths - 1, range(len(lengths))]

was used to reduce x to a 2D tensor of size torch.Size([50, 1]). lengths is itself a tensor of shape torch.Size([50]) containing integer values.

Please can anyone explain how this works? Thank you.

Mofe Adeyemi asked Apr 24 '26 15:04

2 Answers

After being quite stumped by this behavior, I did some more digging and found that it is consistent with the indexing of multi-dimensional NumPy arrays. What makes it counter-intuitive is the less obvious fact that both index arrays have to have the same length, i.e. in this case len(lengths).

In fact, it works as follows:

* lengths determines the order in which you access the first dimension. I.e., if you have a 1D array a = [0, 1, 2, ..., 500] and access it with the list b = [300, 200, 100], then the result is a[b] = [300, 200, 100]. (This also explains the lengths - 1 expression: lengths presumably holds 1-based sequence lengths, so subtracting 1 converts each length into the 0-based index of the last element along the first dimension.)
* range(len(lengths)) then simply chooses the i-th element in the i-th position along the second dimension. If you had a square matrix, you could interpret this as taking its diagonal.

Since you only access a single element for each position along the first two dimensions, the result can be stored in a single dimension (thus reducing your 3D tensor to 2D). The last dimension is simply kept "as is".
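Since the same rules apply to NumPy arrays, here is a minimal sketch of this paired indexing on a small made-up array (the shapes and values below are illustrative, not the ones from the question):

```python
import numpy as np

# Small stand-in for the tensor in the question:
# 4 "time steps", 3 sequences, feature size 1.
x = np.arange(12).reshape(4, 3, 1)
lengths = np.array([2, 4, 1])  # per-sequence lengths (1-based)

# The two index arrays are paired element-wise:
# this picks x[1, 0], x[3, 1], x[0, 2] -> shape (3, 1)
out = x[lengths - 1, range(len(lengths))]
print(out.shape)    # (3, 1)
print(out.ravel())  # [ 3 10  2]
```

Each entry of lengths - 1 is paired with the matching entry of range(len(lengths)), so exactly one element survives per position along the first two axes.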

If you want to play around with this, I strongly recommend changing the range() value to something longer/shorter, which will result in the following error:

IndexError: shape mismatch: indexing arrays could not be broadcast together with shapes (x,) (y,)

where x and y are your specific length values.
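For instance, a quick NumPy sketch (illustrative shapes) that triggers the same class of error:

```python
import numpy as np

x = np.zeros((4, 3, 1))
try:
    # two row indices paired with three column indices: lengths differ
    bad = x[np.array([0, 1]), np.array([0, 1, 2])]
except IndexError as err:
    caught = str(err)
    print(caught)  # shape mismatch: indexing arrays could not be broadcast ...
```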

To write this accessing method out in long form and see what happens "under the hood", consider the example below:

import torch

x = torch.randn(500, 50, 1)  # same shape as in the question
lengths = torch.tensor([2, 30, 1, 4])  # random examples to explore
diag = list(range(len(lengths)))  # [0, 1, 2, 3]
result = []
for i, row in enumerate(lengths):
    temp_tensor = x[row, :, :]  # temp_tensor.shape = [50, 1]
    temp_tensor = temp_tensor[diag[i]]  # temp_tensor.shape = [1]
    result.append(temp_tensor)

# back to a single pytorch tensor
result = torch.stack(result)
result.shape  # torch.Size([4, 1])
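As a sanity check, the loop above can be compared directly against the one-line indexing. A small NumPy sketch (NumPy follows the same indexing rules; lengths is used as-is here, mirroring the loop, rather than lengths - 1):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random((500, 50, 1))
lengths = np.array([2, 30, 1, 4])

# long form: for each i, pick the single row x[lengths[i], i, :]
loop_result = np.stack([x[lengths[i], i] for i in range(len(lengths))])

# one-line paired indexing does the same thing in one shot
fancy = x[lengths, range(len(lengths))]

print(np.array_equal(loop_result, fancy))  # True
```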
dennlinger answered Apr 26 '26 09:04

The key feature here is passing the values of the tensor lengths as indices for x. Here is a simplified example; I swapped the dimensions of container, so the index dimension goes first:

container = torch.arange(0, 50)
container = container.reshape((5, 10))
>>>tensor([[ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        [30, 31, 32, 33, 34, 35, 36, 37, 38, 39],
        [40, 41, 42, 43, 44, 45, 46, 47, 48, 49]])

indices = torch.arange(2, 7, dtype=torch.long)
>>>tensor([2, 3, 4, 5, 6])

print(container[range(len(indices)), indices])
>>>tensor([ 2, 13, 24, 35, 46])    

Note: we take one element from each row (range(len(indices)) produces sequential row numbers), with the column number given by indices[row_number].
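To tie this back to the ordering in the question (where the index dimension comes second), the same example works after transposing container; a small illustrative NumPy sketch, since the indexing rules match:

```python
import numpy as np

container = np.arange(50).reshape(5, 10)
indices = np.arange(2, 7)

# transpose so the index dimension comes second, as in the question:
# now the value indices go in the first slot, the row numbers in the second
swapped = container.T  # shape (10, 5)
picked = swapped[indices, range(len(indices))]
print(picked)  # [ 2 13 24 35 46]
```

Either orientation selects the same elements; only which slot the index arrays occupy changes.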

Alexey Birukov answered Apr 26 '26 08:04


