 

Is One-Hot Encoding required for using PyTorch's Cross Entropy Loss Function?

For example, the MNIST classification problem has 10 output classes. With PyTorch, I would like to use the torch.nn.CrossEntropyLoss function. Do I have to format the targets so that they are one-hot encoded, or can I simply use the class labels that come with the dataset?

asked Jun 18 '20 by Loay Sharaky


People also ask

Does cross entropy loss require one hot encoding?

CrossEntropyLoss expects class indices and does not take one-hot encoded tensors as target labels.
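
A minimal sketch of this usage (the batch size, class count, and label values are illustrative):

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()

logits = torch.randn(4, 10)            # raw, unnormalized scores for 4 samples, 10 classes
targets = torch.tensor([3, 7, 0, 9])   # integer class indices, NOT one-hot vectors

loss = criterion(logits, targets)
print(loss.item())
```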

How do you make one hot encoding in PyTorch?

To create a one-hot encoding in PyTorch, import torch, build a tensor of integer class labels, and pass it to torch.nn.functional.one_hot together with the number of classes.
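
A small sketch with torch.nn.functional.one_hot (the label values and class count are illustrative):

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([3, 7, 0, 9])          # integer class labels
one_hot = F.one_hot(labels, num_classes=10)  # shape (4, 10), dtype int64
print(one_hot)
```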

Does cross entropy loss apply softmax?

Categorical cross-entropy loss is closely related to the softmax function, since it's practically only used with networks with a softmax layer at the output.
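
In PyTorch this shows up as nn.CrossEntropyLoss applying log-softmax internally, so the network itself outputs raw logits. A quick sketch of the equivalence (random logits and illustrative labels):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.tensor([3, 7, 0, 9])

# CrossEntropyLoss is fed raw logits and applies log-softmax internally ...
ce = nn.CrossEntropyLoss()(logits, targets)

# ... which matches log_softmax followed by negative log-likelihood.
nll = nn.NLLLoss()(F.log_softmax(logits, dim=1), targets)

print(torch.allclose(ce, nll))  # True
```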


1 Answer

nn.CrossEntropyLoss expects integer class labels. Internally it never one-hot encodes the label at all; instead it uses the label to index into the output (log-)probability vector and takes that single entry as the per-sample loss. This small but important detail makes computing the loss easier and is equivalent to the one-hot formulation: multiplying by a one-hot target zeroes out every output neuron except the one at the target class, so only that entry contributes to the loss anyway. Therefore, there's no need to one-hot encode your data if the labels are already provided as class indices.
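
A short sketch of that equivalence (batch of 4 samples and 10 classes, values illustrative): indexing the log-probability at the target class gives the same loss as a dot product with a one-hot target.

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)
targets = torch.tensor([3, 7, 0, 9])
log_probs = F.log_softmax(logits, dim=1)

# Pick out the log-probability of the target class for each sample ...
loss_indexed = -log_probs[torch.arange(4), targets].mean()

# ... which equals the dot product with a one-hot target, since every
# other entry of the one-hot vector is zero.
one_hot = F.one_hot(targets, num_classes=10).float()
loss_onehot = -(one_hot * log_probs).sum(dim=1).mean()

print(torch.allclose(loss_indexed, loss_onehot))  # True
```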

The documentation has some more insight on this: https://pytorch.org/docs/master/generated/torch.nn.CrossEntropyLoss.html. In the documentation you'll see the target, which is one of the input parameters. These are your labels, and they are described as:

Target: (N) where each value is 0 ≤ targets[i] ≤ C−1, i.e. an integer class index, or (N, d_1, d_2, ..., d_K) with K ≥ 1 in the case of K-dimensional loss.

This clearly shows how the target should be shaped and what is expected. If you actually wanted to one-hot encode your data, you would need to use torch.nn.functional.one_hot. To replicate what the cross entropy loss is doing under the hood, you would also need nn.functional.log_softmax as the final output, and you would have to write your own loss layer, since none of the built-in PyTorch losses take log-softmax inputs together with one-hot encoded targets. However, nn.CrossEntropyLoss combines both of these operations, so if your outputs are already class labels it is the preferred choice and there is no need to do the conversion; a rough sketch of the hand-rolled alternative follows.
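
If you did want to go the one-hot route, the custom loss described above might look like this (onehot_cross_entropy is a hypothetical helper, not a PyTorch function):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

def onehot_cross_entropy(log_probs, one_hot_targets):
    # Hypothetical custom loss: log-softmax outputs combined with one-hot targets.
    return -(one_hot_targets * log_probs).sum(dim=1).mean()

logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 9])

manual = onehot_cross_entropy(F.log_softmax(logits, dim=1),
                              F.one_hot(labels, num_classes=10).float())
builtin = nn.CrossEntropyLoss()(logits, labels)

print(torch.allclose(manual, builtin))  # True
```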

answered Oct 01 '22 by rayryeng