 

How can I retrieve elements in a multidimensional pytorch tensor by a list of indices?

Tags:

pytorch

I have two tensors, scores and lists.
scores has shape (x, 8) and lists has shape (x, 8, 4). For each row of scores I want to keep only the maximum value, and from lists I want to keep the corresponding element (the length-4 vector at the same position).

Take the following as an example (dimension 8 is reduced to 2 for simplicity):

scores = torch.tensor([[0.5, 0.4], [0.3, 0.8], ...])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]], 
                      [[0.1, 0.2, 0.1, 0.3], 
                       [0.4, 0.3, 0.2, 0.5]], ...])

Then I would like to filter these tensors to:

scores = torch.tensor([0.5, 0.8, ...])
lists = torch.tensor([[0.2, 0.3, 0.1, 0.5], [0.4, 0.3, 0.2, 0.5], ...])

NOTE: So far I have tried retrieving the indices of the maxima from the original scores tensor and using them as an index vector to filter lists:

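import torch

# Loop-based attempt: take the row of each list at the argmax of its scores row
indices = scores.argmax(dim=1)
selected = [lst[idx] for lst, idx in zip(lists, indices)]
selected = torch.stack(selected)  # shape (x, 4)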

That is also where the question title comes from.

asked Sep 15 '25 21:09 by c0mr4t


1 Answer

I imagine you tried something like:

indices = scores.argmax(dim=1)
selection = lists[:, indices]

This does not work because the whole index tensor is applied to every element along dimension 0: each of the x rows is indexed with all x indices, so the result has shape (x, x, 4) instead of the desired (x, 4).
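To see the shape problem concretely, here is a small sketch using the two-row example from the question. The `...` rows are simply dropped, which is an assumption made only to get concrete tensors:

```python
import torch

# Two-row version of the question's example (the "..." rows are omitted)
scores = torch.tensor([[0.5, 0.4], [0.3, 0.8]])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]],
                      [[0.1, 0.2, 0.1, 0.3],
                       [0.4, 0.3, 0.2, 0.5]]])

indices = scores.argmax(dim=1)   # one index per row: [0, 1]
wrong = lists[:, indices]        # every index is applied to every row
print(wrong.shape)               # torch.Size([2, 2, 4]), i.e. (x, x, 4)
```

With x = 2 the extra dimension is easy to miss, but for larger x the result grows quadratically.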

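To perform the correct selection, replace the slice with a range over dimension 0. Indexing with two index sequences of the same length pairs them element-wise, so row i is combined with indices[i]: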

indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]
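A complete, runnable version of the fix on the same two-row example might look like the sketch below. It also selects the matching values from scores, which the question asks for; using `scores.max(dim=1)` to get values and indices in one call is a choice made here, not something the answer prescribes. `torch.arange` stands in for Python's `range`, and a `torch.gather` variant is included as an equivalent alternative:

```python
import torch

scores = torch.tensor([[0.5, 0.4], [0.3, 0.8]])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]],
                      [[0.1, 0.2, 0.1, 0.3],
                       [0.4, 0.3, 0.2, 0.5]]])

# max over dim 1 returns both the maximum values and their positions
best_scores, indices = scores.max(dim=1)

# Pair row i with indices[i]: equal-length index tensors select element-wise
rows = torch.arange(indices.size(0))
best_lists = lists[rows, indices]            # shape (x, 4)

# Equivalent with torch.gather: expand indices to (x, 1, 4), gather along dim 1
gather_idx = indices.view(-1, 1, 1).expand(-1, 1, lists.size(2))
best_lists_gather = lists.gather(1, gather_idx).squeeze(1)

print(best_scores)   # tensor([0.5000, 0.8000])
print(best_lists)
```

The advanced-indexing form is usually the more readable of the two; `gather` can be handy when the selection has to happen along a dimension other than the second one.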
answered Sep 17 '25 18:09 by LGrementieri


