
Multiplying and exponentiating a Python float with a PyTorch integer tensor

Tags: python, pytorch

Why does multiplying a Python float by a torch.long tensor give a torch.float tensor, while raising a Python float to the power of a torch.long tensor gives a torch.long tensor?

>>> import torch
>>> a = 0.9
>>> b = torch.tensor(2, dtype=torch.long)

>>> foo = a * b
>>> print(foo, foo.dtype)
tensor(1.8000) torch.float32

>>> bar = a ** b 
>>> print(bar, bar.dtype)
tensor(0) torch.int64
Milad asked May 11 '26 21:05



1 Answer

This looks like a bug, probably in the way PyTorch binds ** to __rpow__ or __pow__.

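For example, in 0.9 - torch.tensor(2) the left operand is a plain float, and float.__sub__ doesn't know how to handle a tensor (it returns NotImplemented). Python therefore falls back to the right operand's reflected method, torch.tensor(2).__rsub__(0.9), which works correctly. ** is dispatched the same way, but torch.tensor(2).__rpow__(0.9) incorrectly returns tensor(0) with dtype int64.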
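
The reflected-operator fallback described above can be seen without PyTorch at all. The sketch below uses a hypothetical Probe class (purely illustrative, not part of PyTorch) that records which reflected method Python calls. A torch.Tensor on the right-hand side goes through the same __rsub__ / __rpow__ route, since float.__sub__ and float.__pow__ also return NotImplemented for a tensor.

```python
# Hypothetical class, used only to show which special method Python calls.
class Probe:
    def __init__(self):
        self.calls = []  # (method name, left operand) for each reflected call

    def __rsub__(self, other):
        # Reached for `other - self` after other.__sub__ returns NotImplemented
        self.calls.append(("__rsub__", other))
        return "rsub result"

    def __rpow__(self, other):
        # Reached for `other ** self` after other.__pow__ returns NotImplemented
        self.calls.append(("__rpow__", other))
        return "rpow result"


p = Probe()

# float's own methods don't know about Probe, so they return NotImplemented...
print(float.__sub__(0.9, p))  # NotImplemented
print(float.__pow__(0.9, p))  # NotImplemented

# ...and Python falls back to the right operand's reflected method.
print(0.9 - p)                # rsub result
print(0.9 ** p)               # rpow result
print(p.calls)                # [('__rsub__', 0.9), ('__rpow__', 0.9)]
```

So whatever dtype 0.9 ** some_tensor ends up with is decided entirely inside the tensor's __rpow__, which is where the int64 result comes from.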

In the meantime, you can work around it by wrapping the float in a tensor: torch.tensor(0.9) ** torch.tensor(2).
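
A minimal sketch of that workaround applied to the variables from the question. It assumes a PyTorch version with the usual tensor type promotion (a float32 tensor combined with an int64 tensor yields float32). Because both operands are tensors, the expression goes through Tensor.__pow__ on the base rather than the buggy Tensor.__rpow__ path.

```python
import torch

a = 0.9
b = torch.tensor(2, dtype=torch.long)

# Wrap the Python float in a float32 tensor so the expression is
# Tensor ** Tensor (handled by the base's __pow__) instead of
# float ** Tensor (handled by the exponent's __rpow__).
base = torch.tensor(a, dtype=torch.float32)
bar = base ** b

print(bar, bar.dtype)
```

Passing dtype explicitly keeps the result independent of whatever torch.get_default_dtype() happens to be set to.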

Filed a bug: https://github.com/pytorch/pytorch/issues/32436

nairbv answered May 13 '26 12:05
