
How to use the BCELoss in PyTorch?

I want to write a simple autoencoder in PyTorch and use BCELoss. However, I get NaN as the output, since it expects the targets to be between 0 and 1. Could someone post a simple use case of BCELoss?

Qubix asked Apr 30 '17 16:04



1 Answer

Update

BCELoss used to be numerically unstable; see the issue at https://github.com/pytorch/pytorch/issues/751. That issue was resolved by pull request #1792, which added BCEWithLogitsLoss, so a numerically stable binary cross-entropy loss is now available in PyTorch.
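Since the question asks for a simple use case, here is a minimal sketch of BCELoss in a toy autoencoder. The layer sizes, batch size, and the name `model` are arbitrary choices for illustration, not anything from the original question. The two points that matter are that the model ends in a sigmoid, so its outputs lie in [0, 1], and that the targets (here the inputs themselves) also lie in [0, 1].

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy autoencoder; the sizes 8 and 3 are arbitrary.
model = nn.Sequential(
    nn.Linear(8, 3),
    nn.ReLU(),
    nn.Linear(3, 8),
    nn.Sigmoid(),  # BCELoss expects probabilities in [0, 1], not raw logits
)

criterion = nn.BCELoss()

# torch.rand samples from [0, 1), so x is a valid target for BCELoss.
x = torch.rand(4, 8)
reconstruction = model(x)

# The autoencoder is trained to reproduce its own input.
loss = criterion(reconstruction, x)
loss.backward()
```

If your data is not already in [0, 1] (for example raw pixel values in 0..255), rescale it before using it as a target.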


Old answer

If you build PyTorch from source, you can use the numerically stable function BCEWithLogitsLoss (contributed in https://github.com/pytorch/pytorch/pull/1792), which takes logits as input, i.e. the raw model outputs before any sigmoid is applied.
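As a rough illustration of what "takes logits as input" means, the sketch below feeds the same made-up values to both losses: BCEWithLogitsLoss gets the raw logits, while BCELoss gets them after an explicit sigmoid. For moderate logits like these, the two should agree closely.

```python
import torch
import torch.nn as nn

# Made-up example values, chosen only for illustration.
logits = torch.tensor([[-2.0, 0.0, 3.0]])
target = torch.tensor([[0.0, 1.0, 1.0]])

# BCEWithLogitsLoss applies the sigmoid internally.
loss_from_logits = nn.BCEWithLogitsLoss()(logits, target)

# BCELoss needs probabilities, so the sigmoid is applied by hand.
loss_from_probs = nn.BCELoss()(torch.sigmoid(logits), target)
```

In a model, this means dropping the final sigmoid layer when you switch from BCELoss to BCEWithLogitsLoss.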

Otherwise, you can use the following function (contributed by yzgao in the above issue):

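import torch.nn as nn

class StableBCELoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, input, target):
        # input holds raw logits (no sigmoid applied); target holds values in [0, 1].
        neg_abs = -input.abs()
        # Stable BCE on logits: max(x, 0) - x * t + log(1 + exp(-|x|))
        loss = input.clamp(min=0) - input * target + (1 + neg_abs.exp()).log()
        return loss.mean()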
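To see why this form helps, the sketch below compares it with a naive sigmoid-then-log version on deliberately extreme logits. The helper names `stable_bce` and `naive_bce` and the test values are mine, for illustration only; `stable_bce` restates the expression from the class above as a plain function so the snippet runs on its own. At a logit of 100 the sigmoid saturates to exactly 1 in float32, so the naive version hits log(0), which is the kind of NaN described in the question.

```python
import torch
import torch.nn as nn

def stable_bce(logits, target):
    # Same expression as in StableBCELoss.forward, written as a function.
    loss = logits.clamp(min=0) - logits * target + (1 + (-logits.abs()).exp()).log()
    return loss.mean()

def naive_bce(logits, target):
    # Textbook formula: breaks down once the sigmoid saturates to 0 or 1.
    p = torch.sigmoid(logits)
    return -(target * torch.log(p) + (1 - target) * torch.log(1 - p)).mean()

# Extreme logits chosen to force saturation.
logits = torch.tensor([-100.0, -1.0, 0.0, 1.0, 100.0])
target = torch.tensor([1.0, 0.0, 1.0, 0.0, 1.0])

stable = stable_bce(logits, target)
naive = naive_bce(logits, target)
builtin = nn.BCEWithLogitsLoss()(logits, target)

print("stable finite:", torch.isfinite(stable).item())
print("naive finite:", torch.isfinite(naive).item())
print("matches BCEWithLogitsLoss:", torch.allclose(stable, builtin))
```

On a recent PyTorch release, the built-in BCEWithLogitsLoss should make the hand-written class unnecessary.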
cheezer answered Sep 19 '22 09:09
