Why does a Python float multiplied by a torch.long give a torch.float, but raising a Python float to a torch.long power give a torch.long?
>>> import torch
>>> a = 0.9
>>> b = torch.tensor(2, dtype=torch.long)
>>> foo = a * b
>>> print(foo, foo.dtype)
tensor(1.8000) torch.float32
>>> bar = a ** b
>>> print(bar, bar.dtype)
tensor(0) torch.int64
This looks like a bug, probably in the way PyTorch binds ** to __rpow__ / __pow__.
For comparison, if you try 0.9 - torch.tensor(2): since 0.9 isn't a tensor, Python dispatches this to the reflected method torch.tensor(2).__rsub__(0.9), which works correctly. ** dispatches the same way, but torch.tensor(2).__rpow__(0.9) incorrectly returns tensor(0) with dtype int64.
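You can make the dispatch explicit to confirm where the result comes from. A minimal sketch (on PyTorch versions with the bug, both lines print the incorrect tensor(0); on fixed versions both print the correct float result, since a ** b and b.__rpow__(a) are the same call):

```python
import torch

a = 0.9
b = torch.tensor(2, dtype=torch.long)

# float.__pow__(0.9, b) returns NotImplemented because b is not a float,
# so Python falls back to the reflected method on the tensor operand.
implicit = a ** b
explicit = b.__rpow__(a)

# Both expressions go through the exact same dispatch path,
# so they always produce identical results.
print(implicit, implicit.dtype)
print(explicit, explicit.dtype)
```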
In the meantime, you can work around it by making both operands tensors: torch.tensor(0.9) ** torch.tensor(2).
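With two tensor operands, normal type promotion applies and the result is a float, matching the behavior of *:

```python
import torch

b = torch.tensor(2, dtype=torch.long)

# float32 tensor ** int64 tensor promotes to float32,
# avoiding the broken float ** tensor reflected-operator path.
res = torch.tensor(0.9) ** b
print(res, res.dtype)  # 0.81 as a float32 tensor
```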
Filed a bug: https://github.com/pytorch/pytorch/issues/32436