I am trying to solve a simple binary classification problem using an LSTM, and I am trying to figure out the correct loss function for the network. The issue is that when I use binary cross-entropy as the loss function, the loss values for training and testing are relatively high compared with using the mean squared error (MSE) function.
Upon research, I came across justifications that binary cross-entropy should be used for classification problems and MSE for regression problems. However, in my case, I am getting better accuracy and a lower loss value with MSE for binary classification.
I am not sure how to justify these results. Why not use mean squared error for classification problems?
Cross-entropy (sometimes called softmax loss when paired with a softmax output) is a better measure than MSE for classification, because the decision boundary in a classification task is large in comparison with regression.
One of the main reasons why MSE does not work well with logistic regression is that when the MSE loss function is plotted with respect to the weights of the logistic regression model, the resulting curve is not convex, which makes it very difficult to find the global minimum. The non-convexity arises because the linear output is squashed through a sigmoid before the error is squared.
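This can be checked numerically. The sketch below (my own illustration, using a single assumed training example with x = 1 and label y = 1) estimates the second derivative of each loss with respect to the weight w by finite differences: MSE through a sigmoid has regions of negative curvature (non-convex), while binary cross-entropy stays convex.

```python
import numpy as np

def sigmoid(w):
    return 1.0 / (1.0 + np.exp(-w))

def mse(w, y=1.0):
    # squared error of the sigmoid output against the label
    return (sigmoid(w) - y) ** 2

def bce(w, y=1.0):
    # binary cross-entropy of the sigmoid output against the label
    p = sigmoid(w)
    return -(y * np.log(p) + (1 - y) * np.log(1 - p))

def second_diff(f, w, h=1e-3):
    # central finite-difference estimate of the second derivative
    return (f(w + h) - 2 * f(w) + f(w - h)) / h**2

ws = np.linspace(-6, 6, 121)
mse_curv = [second_diff(mse, w) for w in ws]
bce_curv = [second_diff(bce, w) for w in ws]

print(any(c < 0 for c in mse_curv))   # MSE curvature goes negative: non-convex
print(all(c >= 0 for c in bce_curv))  # BCE curvature stays non-negative: convex
```

With a non-convex loss surface, gradient descent can stall in flat regions far from the optimum, which is the practical cost of using MSE here.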
In statistics and signal processing, a minimum mean square error (MMSE) estimator is an estimation method that minimizes the mean square error (MSE) of the fitted values of a dependent variable; MSE is a common measure of estimator quality.
RMSE is a weak evaluation metric for multi-class classification, and also for regression with widely scattered targets, even though "RMSE evaluation" is commonly used when teaching the basic concepts of classification and regression.
I would like to show this using an example. Assume a 6-class classification problem.
Assume, True probabilities = [1, 0, 0, 0, 0, 0]
Case 1: Predicted probabilities = [0.2, 0.16, 0.16, 0.16, 0.16, 0.16]
Case 2: Predicted probabilities = [0.4, 0.5, 0.1, 0, 0, 0]
The MSE in Case 1 and Case 2 is 0.128 and 0.1033, respectively.
Although Case 1 correctly predicts class 1 for the instance, its loss is higher than the loss in Case 2.
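A short sketch reproducing the numbers above: MSE assigns the lower loss to Case 2 even though only Case 1 puts the highest probability on the true class.

```python
import numpy as np

# true one-hot label and the two predicted distributions from the example
y_true = np.array([1, 0, 0, 0, 0, 0], dtype=float)
case1  = np.array([0.2, 0.16, 0.16, 0.16, 0.16, 0.16])
case2  = np.array([0.4, 0.5, 0.1, 0.0, 0.0, 0.0])

mse1 = np.mean((y_true - case1) ** 2)
mse2 = np.mean((y_true - case2) ** 2)

print(round(mse1, 4))  # 0.128
print(round(mse2, 4))  # 0.1033
print(np.argmax(case1) == np.argmax(y_true))  # True: Case 1 classifies correctly
print(mse1 > mse2)                            # True: yet its MSE is higher
```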
Though @nerd21 gives a good example of why MSE is a bad loss function for 6-class classification, the same does not hold for binary classification.
Let's just consider binary classification, with label [1, 0]. Let one prediction be h1 = [p, 1-p] and another be h2 = [q, 1-q]. Their sums of squared errors are:
L1 = 2(1-p)^2, L2 = 2(1-q)^2
Assume h1 is a misclassification, i.e. p < 1-p, so 0 < p < 0.5.
Assume h2 is a correct classification, i.e. q > 1-q, so 0.5 < q < 1.
Then L1 - L2 = 2[(1-p)^2 - (1-q)^2] = 2(p-q)(p+q-2).
Since p < 0.5 < q, we have p - q < 0 for sure. Also p + q < 0.5 + 1 = 1.5, so p + q - 2 < -0.5 < 0.
The product of two negative factors is positive, thus L1 - L2 > 0, i.e. L1 > L2.
This means that for binary classification with MSE as the loss function, a misclassification always incurs a larger loss than a correct classification.
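The inequality can be sanity-checked numerically. The sketch below samples many misclassified predictions p in (0, 0.5) and correctly classified predictions q in [0.5, 1) and confirms L1 > L2 in every case:

```python
import numpy as np

# label is [1, 0]; h1 = [p, 1-p] is misclassified, h2 = [q, 1-q] is correct
rng = np.random.default_rng(0)
p = rng.uniform(0.0, 0.5, size=10_000)  # misclassified: p < 0.5
q = rng.uniform(0.5, 1.0, size=10_000)  # correct: q >= 0.5

L1 = 2 * (1 - p) ** 2  # sum of squared errors for h1
L2 = 2 * (1 - q) ** 2  # sum of squared errors for h2

print(np.all(L1 > L2))  # True: every misclassification loses more
```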