I have two tensors, scores and lists.
scores has shape (x, 8) and lists has shape (x, 8, 4). For each row of scores I want to keep only the maximum value, and from lists keep only the element at the same position.
Take the following as an example (shape dimension 8 was reduced to 2 for simplicity):
scores = torch.tensor([[0.5, 0.4], [0.3, 0.8], ...])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]],
                      [[0.1, 0.2, 0.1, 0.3],
                       [0.4, 0.3, 0.2, 0.5]], ...])
Then I would like to filter these tensors to:
scores = torch.tensor([0.5, 0.8, ...])
lists = torch.tensor([[0.2, 0.3, 0.1, 0.5], [0.4, 0.3, 0.2, 0.5], ...])
NOTE: So far I have tried retrieving the indices of the maxima from scores and using them as an index vector to select from lists:
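# PSEUDO-CODE
indices = scores.argmax(dim=1)
selected = []
for row, idx in zip(lists, indices):  # "row" avoids shadowing the built-in list
    selected.append(row[idx])         # row is (8, 4), row[idx] is (4,)
selected = torch.stack(selected)      # shape (x, 4)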
That is also where the question title comes from.
I imagine you tried something like:
indices = scores.argmax(dim=1)
selection = lists[:, indices]
This does not work because the indices are selected for every element in dimension 0, so the final shape is (x, x, 4).
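The shape claim can be checked with a small sketch. The value x = 3 and the random inputs are assumptions made only for illustration; they are not from the question.

```python
import torch

x = 3  # illustrative batch size (assumption)
scores = torch.rand(x, 8)
lists = torch.rand(x, 8, 4)

indices = scores.argmax(dim=1)  # shape (x,)

# The slice keeps every row of dimension 0, and for each of those rows
# all x entries of `indices` are looked up in dimension 1.
wrong = lists[:, indices]
print(wrong.shape)  # torch.Size([3, 3, 4])

# The wanted elements sit on the "diagonal" of the first two dimensions.
print(torch.equal(wrong[1, 1], lists[1, indices[1]]))
```

So the result contains the desired rows, but mixed in with x - 1 unwanted ones per row.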
To perform the correct selection, you need to replace the slice with a range, so that each row index is paired with exactly one column index:
indices = scores.argmax(dim=1)
selection = lists[range(indices.size(0)), indices]
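For completeness, here is a runnable sketch of this on the example from the question, with the elided "..." entries left out so that x = 2. Using scores.max(dim=1) to get the values and indices in one call, torch.arange in place of range, and the gather variant at the end are my additions, shown as alternatives rather than as the only way to do it.

```python
import torch

scores = torch.tensor([[0.5, 0.4], [0.3, 0.8]])
lists = torch.tensor([[[0.2, 0.3, 0.1, 0.5],
                       [0.4, 0.7, 0.8, 0.2]],
                      [[0.1, 0.2, 0.1, 0.3],
                       [0.4, 0.3, 0.2, 0.5]]])

# max along dim=1 returns the maximum values together with their indices
best_scores, indices = scores.max(dim=1)

# Pair row i with column indices[i]; the result has shape (x, 4)
rows = torch.arange(indices.size(0))
selection = lists[rows, indices]

# Alternative with gather: expand the index to shape (x, 1, 4) first
idx = indices.view(-1, 1, 1).expand(-1, 1, lists.size(2))
gathered = lists.gather(1, idx).squeeze(1)

print(best_scores)  # tensor([0.5000, 0.8000])
print(selection)
print(torch.equal(selection, gathered))  # True
```

Both variants avoid the Python loop; the indexing form is usually the easier one to read.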