
Does PyTorch apply softmax automatically in nn.Linear?

In PyTorch, a classification network is defined like this:

import torch
import torch.nn.functional as F

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer

    def forward(self, x):
        x = F.relu(self.hidden(x))      # activation function for hidden layer
        x = self.out(x)
        return x

Is softmax applied here? In my understanding, it should look like this:

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)   # hidden layer
        self.relu = torch.nn.ReLU(inplace=True)   # note the spelling: ReLU, not ReLu
        self.out = torch.nn.Linear(n_hidden, n_output)   # output layer
        self.softmax = torch.nn.Softmax(dim=1)   # dim is the class axis, not n_output

    def forward(self, x):
        x = self.hidden(x)
        x = self.relu(x)      # activation function for hidden layer
        x = self.out(x)
        x = self.softmax(x)
        return x

I understand that F.relu(self.hidden(x)) also applies ReLU, but the first block of code doesn't apply softmax, right?

yujuezhao asked Aug 15 '19 20:08


1 Answer

Latching on to what @jodag was already saying in his comment, and extending it a bit to form a full answer:

No, PyTorch does not apply softmax automatically: nn.Linear returns raw scores (logits), and you can apply torch.nn.Softmax() yourself wherever you want. However, softmax has some issues with numerical stability, which we want to avoid as much as we can. One solution is to use log-softmax, but this tends to be slower than a direct computation.
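As a minimal sketch of the point above (the layer sizes and batch size are made up for illustration), the output of nn.Linear is just the affine map x @ W.T + b, and it only becomes a probability distribution once softmax is applied explicitly along the class dimension:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical sizes for illustration: 4 input features, 3 classes, batch of 5.
layer = torch.nn.Linear(4, 3)
x = torch.randn(5, 4)

logits = layer(x)                    # raw scores; no softmax is applied here
manual = x @ layer.weight.T + layer.bias   # the same affine map written by hand

probs = F.softmax(logits, dim=1)     # normalize across the class dimension

row_sums_logits = logits.sum(dim=1)  # arbitrary values
row_sums_probs = probs.sum(dim=1)    # each entry is 1 up to floating-point error

print(row_sums_logits)
print(row_sums_probs)
```

Note that dim=1 refers to the axis holding the class scores for a (batch, classes) tensor; it is not the number of classes.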

Especially when we are using negative log-likelihood as a loss function (in PyTorch, this is torch.nn.NLLLoss), we can utilize the fact that the derivative of (log-)softmax combined with the NLL loss is mathematically quite nice and simple, which is why it makes sense to combine the two into a single function. The result is torch.nn.CrossEntropyLoss, which expects raw logits as input, so the first block of code in the question is exactly what you would feed into it. Again, note that this only applies to the last layer of your network; any other computation is not affected by any of this.
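The relationship between these losses can be checked with a short sketch (the logits and targets are made-up values, not taken from the question). It compares torch.nn.CrossEntropyLoss on raw logits with log-softmax followed by torch.nn.NLLLoss, and also shows why taking the log of a separately computed softmax can be fragile for extreme logits:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical batch: 5 samples, 3 classes.
logits = torch.randn(5, 3)                # raw network outputs, no softmax
targets = torch.tensor([0, 2, 1, 1, 0])   # class indices

# CrossEntropyLoss takes raw logits directly.
loss_ce = torch.nn.CrossEntropyLoss()(logits, targets)

# Equivalent two-step version: log-softmax, then negative log-likelihood.
loss_nll = torch.nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)

# Extreme logits: softmax underflows to exactly 0 for the small entries,
# so taking its log afterwards produces -inf.
big = torch.tensor([[1000.0, 0.0, -1000.0]])
naive = torch.log(F.softmax(big, dim=1))
stable = F.log_softmax(big, dim=1)        # stays finite

print(loss_ce.item(), loss_nll.item())
print(naive)
print(stable)
```

If probabilities are needed at inference time, softmax can still be applied to the model output outside the loss computation.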

dennlinger answered Oct 02 '22 05:10
