 

PyTorch Tensors - vectorized slicing with given list of end indices

Tags:

python

pytorch

Suppose I have a 1D PyTorch tensor end_index of length L.

I want to construct a 2D PyTorch tensor T with L rows, where T[i, j] = 2 when j < end_index[i] and T[i, j] = 1 otherwise.

The following works:

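import torch

T = torch.ones([len(end_index), 3], dtype=torch.long)  # one row per entry, 3 columns
for i, element in enumerate(end_index):
    T[i, :element] = 2  # fill only row i up to (not including) end_index[i]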

Is there a vectorized way to do this?

asked Nov 04 '25 08:11 by PhysicsPrincess


1 Answer

You can construct such a tensor using broadcasting semantics:

import torch

# sample inputs
L, C = 4, 3
end_index = torch.tensor([0, 2, 2, 1])

# Construct tensor of shape [L, C] such that for all (i, j)
#     T[i, j] = 2 if j < end_index[i] else 1
j_range = torch.arange(C, device=end_index.device)
T = (j_range[None, :] < end_index[:, None]).long() + 1

which results in:

T = 
tensor([[1, 1, 1],
        [2, 2, 1],
        [2, 2, 1],
        [2, 1, 1]])
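To make the broadcasting step explicit, and to check the result against a row-by-row loop, here is a small sketch. `fill_from_end_index` and `fill_with_loop` are hypothetical helper names introduced only for illustration; the `inside`/`outside` parameters generalize the hard-coded 2 and 1 by using `torch.where` instead of `.long() + 1`, and `end_index` is assumed to hold non-negative integers.

```python
import torch


def fill_from_end_index(end_index: torch.Tensor, num_cols: int,
                        inside: int = 2, outside: int = 1) -> torch.Tensor:
    """Hypothetical helper: T[i, j] = inside if j < end_index[i] else outside."""
    # Column indices 0..num_cols-1, on the same device as end_index, shape [C]
    j_range = torch.arange(num_cols, device=end_index.device)
    # Shapes [1, C] and [L, 1] broadcast to a boolean mask of shape [L, C]
    mask = j_range[None, :] < end_index[:, None]
    # Constant tensors of the mask's shape, then select element-wise
    inside_t = torch.full(mask.shape, inside, dtype=torch.long, device=mask.device)
    outside_t = torch.full(mask.shape, outside, dtype=torch.long, device=mask.device)
    return torch.where(mask, inside_t, outside_t)


def fill_with_loop(end_index: torch.Tensor, num_cols: int,
                   inside: int = 2, outside: int = 1) -> torch.Tensor:
    """Hypothetical reference: the explicit per-row loop, for comparison only."""
    T = torch.full((len(end_index), num_cols), outside, dtype=torch.long)
    for i, e in enumerate(end_index.tolist()):
        T[i, :e] = inside  # only row i is touched
    return T


end_index = torch.tensor([0, 2, 2, 1])
print(torch.arange(3)[None, :].shape)  # torch.Size([1, 3])
print(end_index[:, None].shape)        # torch.Size([4, 1])

T = fill_from_end_index(end_index, 3)
print(T)
assert torch.equal(T, fill_with_loop(end_index, 3))
```

Indexing with `None` inserts a size-1 axis, so comparing a [1, C] row of column indices with an [L, 1] column of end indices yields the full [L, C] mask in a single operation, without a Python loop over the rows.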
answered Nov 07 '25 02:11 by jodag


