Is there an equivalent PyTorch loss function for TensorFlow's softmax_cross_entropy_with_logits?
torch.nn.functional.cross_entropy
This takes logits as inputs (performing log_softmax internally). Here "logits" are just some values that are not probabilities (i.e. not necessarily in the interval [0, 1]).
But logits are also the values that will be converted to probabilities. If you consider the name of the TensorFlow function, you will see it is a pleonasm (since the with_logits part assumes softmax will be called).
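For example, softmax maps arbitrary real-valued logits to a probability distribution (a minimal sketch with made-up values):

import torch

logits = torch.tensor([2.0, -1.0, 0.5])  # arbitrary real values, not probabilities
probs = torch.softmax(logits, dim=-1)    # each value now in [0, 1], summing to 1
print(probs.sum())                       # tensor(1.)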
The PyTorch implementation looks like this:
import torch.nn.functional as F
loss = F.cross_entropy(x, target)  # x holds raw logits
Which is equivalent to:
lp = F.log_softmax(x, dim=-1)
loss = F.nll_loss(lp, target)
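A quick numerical check of this equivalence (a minimal sketch; the shapes and values are just illustrative):

import torch
import torch.nn.functional as F

x = torch.randn(3, 5)             # raw logits for 3 samples, 5 classes
target = torch.tensor([1, 0, 4])  # ground-truth class indices

loss = F.cross_entropy(x, target)           # logits in, log_softmax applied internally
loss2 = F.nll_loss(F.log_softmax(x, dim=-1), target)

print(torch.allclose(loss, loss2))  # True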
It is not F.binary_cross_entropy_with_logits, because that function assumes multi-label classification:
F.sigmoid + F.binary_cross_entropy = F.binary_cross_entropy_with_logits
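That identity can be checked the same way (a sketch with made-up multi-label targets; torch.sigmoid is used here since F.sigmoid is deprecated):

import torch
import torch.nn.functional as F

logits = torch.randn(3, 5)                    # one independent score per label
labels = torch.randint(0, 2, (3, 5)).float()  # multi-label 0/1 targets

loss = F.binary_cross_entropy_with_logits(logits, labels)
loss2 = F.binary_cross_entropy(torch.sigmoid(logits), labels)

print(torch.allclose(loss, loss2))  # True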
It is not torch.nn.functional.nll_loss either, because that function takes log-probabilities (after log_softmax()), not logits.
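A sketch showing why raw logits are the wrong input for nll_loss (illustrative values again):

import torch
import torch.nn.functional as F

x = torch.randn(3, 5)
target = torch.tensor([1, 0, 4])

wrong = F.nll_loss(x, target)                         # silently treats logits as log-probabilities
right = F.nll_loss(F.log_softmax(x, dim=-1), target)

print(torch.allclose(wrong, right))  # False in general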