I don't understand this line:
lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)
There is no comment, so is it some well-known Python (or PyTorch?) idiom? Could someone explain what it means, or show a different way that makes the intent clearer?
lprobs is a PyTorch Tensor, and it could contain a float type of any size (I doubt this code is intended to support int or complex types). As far as I know, the Tensor classes don't override the __ne__ function.
It's a combination of fancy indexing with a boolean mask and a "trick" (although intended by design) to check for NaN: x != x holds iff x is NaN (for floats, that is).
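To make the two pieces concrete, here is a minimal sketch with made-up sample values (the name nan_mask is only for illustration). Note that != on tensors is evaluated elementwise and returns a boolean tensor, which is what makes it usable as a mask:

```python
import math
import torch

# Hypothetical sample data containing two NaNs.
lprobs = torch.tensor([0.5, float("nan"), -1.0, float("nan")])

# NaN is the only float value that compares unequal to itself,
# so this elementwise comparison marks exactly the NaN positions.
nan_mask = lprobs != lprobs

# Assigning through a boolean mask overwrites only the masked elements, in place.
lprobs[nan_mask] = -math.inf

print(nan_mask)
print(lprobs)
```

Assigning the plain Python float -math.inf is enough here; the torch.tensor(...).to(lprobs) in the original line just makes the dtype and device match explicitly.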
They could alternatively have written
lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)
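or, probably even more idiomatically, used torch.nan_to_num (but beware that the latter also has special behaviour towards infinities: by default it replaces NaN with 0 and replaces positive and negative infinity with the largest and smallest finite values of the dtype, so the nan, posinf and neginf arguments would need to be set explicitly).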
A non-updating variant of the above, which returns a new tensor and leaves lprobs unchanged, would be
torch.where(torch.isnan(lprobs), torch.tensor(-math.inf), lprobs)
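As a rough comparison of these alternatives, the sketch below uses made-up sample values that include existing infinities, to show where torch.nan_to_num with default arguments differs from the torch.where variant. The variable names are only for illustration:

```python
import math
import torch

# Hypothetical sample data: one NaN plus both infinities.
lprobs = torch.tensor([0.5, float("nan"), float("inf"), float("-inf")])

# Out-of-place: only NaN is replaced; existing infinities are left alone.
via_where = torch.where(torch.isnan(lprobs), torch.tensor(-math.inf), lprobs)

# Default arguments: NaN becomes 0 and both infinities are replaced
# by the largest/smallest finite values of the dtype.
default_n2n = torch.nan_to_num(lprobs)

# Explicit arguments: map NaN to -inf and keep the infinities as they are.
explicit_n2n = torch.nan_to_num(
    lprobs, nan=-math.inf, posinf=math.inf, neginf=-math.inf
)

print(via_where)
print(default_n2n)
print(explicit_n2n)
```

With the explicit arguments, nan_to_num gives the same result as the torch.where variant; with the defaults it would silently turn the NaN into 0 and the infinities into finite numbers, which is not what the original line does.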