For example, given this code:
import torch
input = torch.randn(3, 5)  # 3 rows, 5 columns, matching the output shown here
result = torch.argmax(input, dim=0, keepdim=True)
input is:
tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
        [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
        [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])
and result is:
tensor([[ 0, 2, 1, 2, 2]])
However, I want a one-hot result like this, with a 1 in each column at the row that holds the maximum:
tensor([[ 1, 0, 0, 0, 0],
        [ 0, 0, 1, 0, 0],
        [ 0, 1, 0, 1, 1]])
I eventually solved it, but this solution may not be efficient, since it builds one mask per row index by hand. The code is as follows:
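input = torch.randn(3, 5)
# Row index of the maximum in each column, shape (1, 5)
result = torch.argmax(input, dim=0, keepdim=True)
# One boolean mask per possible row index
result_0 = result == 0
result_1 = result == 1
result_2 = result == 2
# Stack the masks into a (3, 5) tensor and convert True/False to 1/0
result = torch.cat((result_0, result_1, result_2), 0).long()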
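A possibly simpler way to get the same result without one mask per row is sketched below. It is only a sketch, not necessarily the fastest option: it assumes the input is 2-D and that argmax is taken over dim 0. The name x (used instead of input, which shadows a Python builtin) and the hard-coded values copied from the example above are my own choices for illustration. It shows two variants, one with torch.nn.functional.one_hot and one with scatter_.

```python
import torch
import torch.nn.functional as F

# Values copied from the example above so the output is reproducible.
x = torch.tensor([[ 1.5742,  0.8183, -2.3005, -1.1650, -0.2451],
                  [ 1.0553,  0.6021, -0.4938, -1.5379, -1.2054],
                  [-0.1728,  0.8372, -1.9181, -0.9110,  0.2422]])

# Variant 1: one_hot on the argmax indices.
idx = torch.argmax(x, dim=0)                         # shape (5,), dtype int64
# one_hot returns shape (5, 3); transpose to get (3, 5) like the input.
one_hot = F.one_hot(idx, num_classes=x.shape[0]).T

# Variant 2: scatter 1s into a zero tensor along dim 0.
idx_keep = torch.argmax(x, dim=0, keepdim=True)      # shape (1, 5)
one_hot_scatter = torch.zeros_like(x, dtype=torch.long).scatter_(0, idx_keep, 1)

print(one_hot)
```

Both variants work for any number of rows, so nothing has to be written out per index. If a column contains ties, only the single row index returned by argmax is marked.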