python keras lstm cross-entropy mean-square-error

Why not use mean squared error for classification problems?


I am trying to solve a simple binary classification problem using an LSTM, and I am trying to figure out the correct loss function for the network. The issue is that when I use binary cross-entropy as the loss function, the training and test loss values are relatively high compared to those I get with the mean squared error (MSE) function.

While researching, I came across justifications that binary cross-entropy should be used for classification problems and MSE for regression problems. However, in my case, I am getting better accuracy and lower loss values with MSE for binary classification.

I am not sure how to explain these results. Why not use mean squared error for classification problems?


Solution

  • I'd like to share my understanding of the MSE and binary cross-entropy functions.

    In classification, the predicted class for each instance is the argmax of its predicted class probabilities.

    Now, consider a binary classifier that predicts the probabilities [0.49, 0.51] for an instance. In this case, the predicted class is 1.

    Now, assume that the actual label is also 1.

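    In such a case, the predicted class is correct, but both losses are computed on the predicted probability, not on the predicted class. MSE gives (1 - 0.51)^2 = 0.2401, whereas binary cross-entropy gives -ln(0.51) ≈ 0.673. If the trained model predicts similarly uncertain probabilities for all data samples, its accuracy is 100%, yet both losses stay close to what a coin-flip prediction of 0.5 would get (0.25 for MSE, about 0.693 for binary cross-entropy).

    So accuracy can make such a model look perfect even though its probabilities are barely more confident than guessing. Also, the two loss values are on different scales: MSE on probabilities never exceeds 1, while binary cross-entropy grows without bound as the predicted probability of the true class approaches 0, so it penalizes confident mistakes far more heavily. A lower MSE value therefore does not mean a better model than one with a higher binary cross-entropy value, and that heavy penalty on confident mistakes is why binary cross-entropy, rather than MSE, should be used for classification.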
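A minimal pure-Python sketch of the losses for the example above (probabilities [0.49, 0.51], true label 1), assuming one predicted probability of class 1 per sample and 0/1 labels. The helper names `predicted_class`, `squared_error` and `binary_cross_entropy` are hypothetical, and clipping the probability with `eps` to avoid `log(0)` is an assumption of this sketch, not a copy of Keras's implementation.

```python
import math


def predicted_class(probs):
    """Index of the largest probability (argmax), used to pick the class."""
    return max(range(len(probs)), key=lambda i: probs[i])


def squared_error(p, y):
    """Squared error between the predicted probability of class 1 and the 0/1 label."""
    return (y - p) ** 2


def binary_cross_entropy(p, y, eps=1e-7):
    """Binary cross-entropy for one sample; p is clipped (assumed eps) so log(0) never occurs."""
    p = min(max(p, eps), 1 - eps)
    return -(y * math.log(p) + (1 - y) * math.log(1 - p))


# The example from the answer: probabilities [0.49, 0.51], true label 1.
probs = [0.49, 0.51]
label = 1
p1 = probs[1]  # predicted probability of class 1

print(predicted_class(probs))                     # 1: the prediction is correct
print(round(squared_error(p1, label), 4))         # 0.2401
print(round(binary_cross_entropy(p1, label), 4))  # 0.6733

# The same uncertain prediction for 1,000 positive samples:
# accuracy is 100%, but the mean losses do not drop toward 0.
preds = [0.51] * 1000
labels = [1] * 1000
n = len(labels)
accuracy = sum((p >= 0.5) == (y == 1) for p, y in zip(preds, labels)) / n
mean_mse = sum(squared_error(p, y) for p, y in zip(preds, labels)) / n
mean_bce = sum(binary_cross_entropy(p, y) for p, y in zip(preds, labels)) / n
print(accuracy, round(mean_mse, 4), round(mean_bce, 4))  # 1.0 0.2401 0.6733

# A confident mistake: squared error is capped at 1, cross-entropy keeps growing.
print(round(squared_error(0.01, 1), 4))         # 0.9801
print(round(binary_cross_entropy(0.01, 1), 4))  # 4.6052
```

For the two-column form [0.49, 0.51] against the one-hot target [0, 1], the mean squared error over both columns is also 0.2401 and the categorical cross-entropy is also -ln(0.51), so the comparison does not depend on which output layout the network uses.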