What does x[x!=x] mean?

I don't understand this line:

lprobs[lprobs != lprobs] = torch.tensor(-math.inf).to(lprobs)

There is no comment, so is it some well-known Python (or PyTorch?) idiom? Could someone explain what it means, or show a different way that makes the intent clearer?

lprobs is a PyTorch Tensor, and it could hold a float type of any size (I doubt this code is intended to support int or complex types). As far as I know, the Tensor classes don't override the __ne__ function.

Darren Cook asked Dec 22 '22 15:12

1 Answer

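It's a combination of fancy indexing with a boolean mask and a "trick" (although intended by design) to check for NaN: x != x holds iff x is NaN (for floats, that is), because IEEE 754 defines NaN to compare unequal to every value, including itself. So lprobs != lprobs is a boolean tensor that is True exactly at the NaN positions, and the assignment replaces every NaN in lprobs with -inf.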
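
A minimal sketch of the idiom on a small tensor (the values are made up for illustration). It shows that != between tensors is elementwise and returns a boolean tensor, that this mask matches torch.isnan, and that the boolean-mask assignment fills only the NaN slots while leaving existing infinities alone:

```python
import math

import torch

# Illustrative values: one NaN plus both infinities, which should be left alone.
lprobs = torch.tensor([-0.5, float("nan"), -math.inf, math.inf, -2.0])

# Tensor != Tensor is elementwise and yields a bool tensor.
# Only the NaN entry compares unequal to itself.
mask = lprobs != lprobs
assert torch.equal(mask, torch.isnan(lprobs))

# Boolean-mask assignment writes -inf into the NaN positions, in place.
# .to(lprobs) gives the 0-dim fill value the same dtype and device as lprobs.
lprobs[mask] = torch.tensor(-math.inf).to(lprobs)
```

After the assignment, lprobs holds [-0.5, -inf, -inf, inf, -2.0].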

They could alternatively have written

lprobs[torch.isnan(lprobs)] = torch.tensor(-math.inf).to(lprobs)

or, probably even more idiomatically, used torch.nan_to_num (but beware that, by default, it also replaces +inf and -inf with the largest and smallest finite values of the dtype).
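
A sketch comparing torch.nan_to_num's defaults with explicit arguments (example values are illustrative). Passing nan=-math.inf and mapping the infinities to themselves should reproduce the masked assignment without mutating lprobs; lprobs.nan_to_num_(...) is the in-place form:

```python
import math

import torch

lprobs = torch.tensor([-0.5, float("nan"), -math.inf, math.inf])
finfo = torch.finfo(lprobs.dtype)

# Defaults: NaN -> 0.0, +inf -> finfo.max, -inf -> finfo.min.
default = torch.nan_to_num(lprobs)

# Explicit arguments: NaN -> -inf, and +/-inf are mapped to themselves,
# matching lprobs[lprobs != lprobs] = -inf but returning a new tensor.
matched = torch.nan_to_num(lprobs, nan=-math.inf, posinf=math.inf, neginf=-math.inf)
```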

An out-of-place variant of the masked assignment, which returns a new tensor instead of modifying lprobs, would be

torch.where(torch.isnan(lprobs), torch.tensor(-math.inf).to(lprobs), lprobs)
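
A short sketch of the torch.where form (illustrative values). The fill value is converted with .to(lprobs), as in the question's code, so its dtype and device match the input; lprobs keeps its NaN and only the returned tensor changes:

```python
import math

import torch

lprobs = torch.tensor([-0.5, float("nan"), -2.0])

# Fill value with the same dtype and device as lprobs.
fill = torch.tensor(-math.inf).to(lprobs)

# Take `fill` where the condition is True, lprobs elsewhere; returns a new tensor.
cleaned = torch.where(torch.isnan(lprobs), fill, lprobs)

# lprobs itself is not modified: its NaN is still there.
assert torch.isnan(lprobs).tolist() == [False, True, False]
```

This is the form to reach for when other code still needs the original lprobs.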
phipsgabler answered Jan 02 '23 16:01