PyTorch equivalence for softmax_cross_entropy_with_logits

I was wondering: is there an equivalent PyTorch loss function for TensorFlow's softmax_cross_entropy_with_logits?

asked Sep 14 '17 by Dark_Voyager


1 Answer

is there an equivalent PyTorch loss function for TensorFlow's softmax_cross_entropy_with_logits?

torch.nn.functional.cross_entropy

This takes logits as input (performing log_softmax internally). Here "logits" are just raw, unnormalized scores rather than probabilities (i.e. not necessarily in the interval [0, 1]).

But logits are also the values that will be converted to probabilities. If you consider the name of the TensorFlow function, you will see it contains a pleonasm: the with_logits part already implies that softmax will be applied internally.

In PyTorch it looks like this:

import torch.nn.functional as F

loss = F.cross_entropy(x, target)

This is equivalent to:

lp = F.log_softmax(x, dim=-1)  # convert logits to log-probabilities
loss = F.nll_loss(lp, target)  # negative log-likelihood over log-probabilities
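
As a minimal, runnable sanity check of this equivalence (the shapes and seed here are illustrative, not from the original answer):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(8, 5)                # logits: batch of 8 samples, 5 classes
target = torch.randint(0, 5, (8,))   # integer class indices

loss_a = F.cross_entropy(x, target)
loss_b = F.nll_loss(F.log_softmax(x, dim=-1), target)

print(torch.allclose(loss_a, loss_b))  # True: both compute the same loss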

It is not F.binary_cross_entropy_with_logits, because that function assumes multi-label classification (an independent binary decision for each class):

F.sigmoid + F.binary_cross_entropy = F.binary_cross_entropy_with_logits
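
A quick sketch of that relationship (torch.sigmoid is the current spelling of the deprecated F.sigmoid; shapes are illustrative):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(8, 5)                    # 5 independent binary decisions per sample
labels = torch.randint(0, 2, (8, 5)).float()  # multi-label targets in {0, 1}

loss_a = F.binary_cross_entropy(torch.sigmoid(logits), labels)
loss_b = F.binary_cross_entropy_with_logits(logits, labels)

print(torch.allclose(loss_a, loss_b))  # True, up to floating-point error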

It is not torch.nn.functional.nll_loss either, because that function takes log-probabilities (i.e. the output of log_softmax()), not logits.

answered Oct 12 '22 by prosti