 

PyTorch LogSoftmax vs Softmax for CrossEntropyLoss

I understand that PyTorch's LogSoftmax function is basically just a more numerically stable way to compute Log(Softmax(x)). Softmax lets you convert the output from a Linear layer into a categorical probability distribution.

The PyTorch documentation says that CrossEntropyLoss combines nn.LogSoftmax() and nn.NLLLoss() in one single class.

Looking at NLLLoss, I'm still confused... are there two logs being used? I think of the negative log as the information content of an event (as in entropy).

After a bit more looking, I think that NLLLoss assumes that you're actually passing in log probabilities instead of just probabilities. Is this correct? It's kind of weird if so...

asked Dec 08 '20 by JacKeown



1 Answer

Yes, NLLLoss takes log-probabilities (log(softmax(x))) as input. Why? Because if you add an nn.LogSoftmax (or F.log_softmax) as the final layer of your model, you can easily get the probabilities using torch.exp(output), and to get the cross-entropy loss you can directly use nn.NLLLoss. Of course, log-softmax is more stable, as you said.
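For example, here's a minimal sketch of that pattern (the tensors, shapes, and class count are made up purely for illustration):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy values standing in for a model's final Linear layer output: 3 classes, batch of 2.
torch.manual_seed(0)
logits = torch.randn(2, 3)
target = torch.tensor([0, 2])

log_probs = F.log_softmax(logits, dim=1)   # log(softmax(x)), computed stably
probs = torch.exp(log_probs)               # recover ordinary probabilities when you need them
print(probs.sum(dim=1))                    # each row sums to ~1

loss = nn.NLLLoss()(log_probs, target)     # NLLLoss expects log-probabilities, not probabilities
print(loss)
```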

And there is only one log (it's in nn.LogSoftmax); there is no log in nn.NLLLoss.

nn.CrossEntropyLoss() combines nn.LogSoftmax() (log(softmax(x))) and nn.NLLLoss() in one single class. Therefore, what you pass into nn.CrossEntropyLoss needs to be the raw output of the network (the logits), not the output of a softmax.
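A minimal sketch of that equivalence, again with made-up tensors:

```python
import torch
import torch.nn as nn

# Same style of toy logits/targets as above, made up for illustration.
torch.manual_seed(0)
logits = torch.randn(2, 3)                 # raw network output -- do NOT softmax these first
target = torch.tensor([0, 2])

# Route 1: CrossEntropyLoss applied directly to the logits
ce = nn.CrossEntropyLoss()(logits, target)

# Route 2: LogSoftmax followed by NLLLoss
nll = nn.NLLLoss()(nn.LogSoftmax(dim=1)(logits), target)

print(torch.allclose(ce, nll))             # True: the two routes compute the same loss
```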

answered Sep 29 '22 by kHarshit