I have two tensors: scores and lists. scores is of shape (x, 8) and lists is of shape (x, 8, 4). I want to select the max value of each row in scores and the corresponding elements of lists.
Take the following as an example (the dimension of size 8 was reduced to 2 for simplicity):
scores = torch.tensor([[0.5, 0.4], [0.3, 0.8], ...])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
[0.4, 0.7, 0.8, 0.2]],
[[0.1, 0.2, 0.1, 0.3],
[0.4, 0.3, 0.2, 0.5]], ...])
Then I would like to filter these tensors to:
scores = torch.tensor([0.5, 0.8, ...])
lists = torch.tensor([[0.2, 0.3, 0.1, 0.5], [0.4, 0.3, 0.2, 0.5], ...])
NOTE: So far I have tried to retrieve the indices from the original scores tensor and use them as an index vector to filter lists:
# PSEUDO-CODE
indices = scores.argmax(dim=1)
for list, idx in zip(lists, indices):
    list = list[idx]
That is also where the question name is coming from.
I imagine you tried something like
indices = scores.argmax(dim=1)
selection = lists[:, indices]
This does not work because the full set of indices is applied to every element in dimension 0, so the final shape is (x, x, 4) instead of the desired (x, 4).
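A quick shape check makes the problem visible. This is only a minimal sketch: the random tensors and the batch size x = 3 are assumptions chosen for illustration, not values from the question.

```python
import torch

x = 3  # assumed batch size, for illustration only
scores = torch.rand(x, 8)
lists = torch.rand(x, 8, 4)

indices = scores.argmax(dim=1)   # shape (x,)
selection = lists[:, indices]    # each of the x rows is indexed with all x indices

print(selection.shape)  # torch.Size([3, 3, 4])
```

The entries you actually want sit on the "diagonal" of that result, i.e. selection[i, i] equals lists[i, indices[i]].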
To perform the correct selection, you need to replace the slice with a range, so that each row index is paired with its own column index.
indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]
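As a runnable sketch of the same idea, here it is applied to the two rows written out in the question (the elided `...` rows are left out, so x = 2 here). Using `scores.max(dim=1)` to get the values and indices in one call, and `torch.arange` in place of `range`, are my own choices and not part of the original snippet.

```python
import torch

# Only the two rows shown explicitly in the question.
scores = torch.tensor([[0.5, 0.4], [0.3, 0.8]])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]],
                      [[0.1, 0.2, 0.1, 0.3],
                       [0.4, 0.3, 0.2, 0.5]]])

# max along dim 1 returns both the max values and their positions.
max_scores, indices = scores.max(dim=1)

# Pair row i with column indices[i].
rows = torch.arange(scores.size(0))
selection = lists[rows, indices]

print(max_scores.shape)  # torch.Size([2])
print(selection.shape)   # torch.Size([2, 4])
```

Here max_scores holds 0.5 and 0.8, and selection holds the rows [0.2, 0.3, 0.1, 0.5] and [0.4, 0.3, 0.2, 0.5], which matches the desired output in the question.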