 

PyTorch tensor advanced indexing

Say I have one matrix and one vector as follows:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

Is there a way to index x with y so that the result is:

res = [1, 6, 8]

In other words, for each row i of x, I want the element in column y[i]: the first entry of y selects the column in the first row, the second entry selects the column in the second row, and so on.

Dr. Prof. Patrick asked Feb 03 '23 15:02



1 Answer

You can pass an explicit row index alongside y, so that row i is paired with column y[i]:

import torch
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

y = torch.tensor([0, 2, 1])

x[range(x.shape[0]), y]
tensor([1, 6, 8])
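
For comparison, here is a small sketch of why plain x[y] does not give this result, next to two ways that do. The variable names (rows_only, rows, picked, gathered) are illustrative and not part of the original answer, and the torch.gather variant is an alternative I am adding, assuming y holds integer (int64) column indices:

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
y = torch.tensor([0, 2, 1])

# A single index tensor selects whole rows, so x[y] has shape (3, 3).
rows_only = x[y]

# Pairing explicit row indices 0..n-1 with y picks one element per row.
rows = torch.arange(x.shape[0])
picked = x[rows, y]

# Alternative (assumption, not from the answer): gather along dim 1.
# The index must have the same number of dims as x, hence unsqueeze(1).
gathered = x.gather(1, y.unsqueeze(1)).squeeze(1)

print(picked)    # tensor([1, 6, 8])
print(gathered)  # tensor([1, 6, 8])
```

Both forms should return the same values here. The indexing form is shorter, while gather may be easier to extend to batched inputs.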
FBruzzesi answered Feb 06 '23 15:02
