RMSE loss for multi output regression problem in PyTorch

I'm training a CNN in PyTorch to solve a regression problem where my output is a tensor of 20 values. I planned to use RMSE as my loss function and tried using PyTorch's nn.MSELoss(), taking the square root of the result with torch.sqrt(), but got confused by the results. I'll try my best to explain why. For a batch size bs, my output tensor's dimensions are [bs, 20]. I tried to implement an RMSE function of my own:

def loss_function(predicted_x, target):
    # Mean of the squared errors over the outputs (20 in my case)
    loss = torch.sum(torch.square(predicted_x - target), dim=1) / predicted_x.size()[1]
    loss = torch.sqrt(loss)
    # Average over the batch
    loss = torch.sum(loss) / predicted_x.size()[0]
    return loss

But the output of my loss_function() differs from what PyTorch computes with nn.MSELoss(). I'm not sure whether my implementation is wrong or whether I'm using nn.MSELoss() the wrong way.

asked Oct 16 '22 by cronin

2 Answers

The MSE loss is the mean of the squares of the errors. You're taking a square root, so your function and PyTorch's nn.MSELoss() compute different values. More precisely, your function averages the per-sample RMSEs, while nn.MSELoss() returns the overall mean of the squared errors; since the square root is nonlinear, mean(sqrt(x)) is generally not equal to sqrt(mean(x)), so the two outputs cannot be compared directly.
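To make the gap concrete, here is a minimal sketch (the tensors and values are made up for illustration):

import torch
import torch.nn as nn

pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.zeros(2, 2)

# Mean of per-sample RMSEs (what the question's loss_function computes)
mean_of_rmses = torch.sqrt(torch.mean((pred - target) ** 2, dim=1)).mean()  # ~2.5583

# Square root of the overall MSE (sqrt applied after averaging everything)
rmse_of_mean = torch.sqrt(nn.MSELoss()(pred, target))  # ~2.7386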

However, you can use nn.MSELoss() to build your own RMSE loss:

loss_fn = nn.MSELoss()
# Square root of the batch-mean squared error; autograd handles the sqrt
RMSE_loss = torch.sqrt(loss_fn(prediction, target))
RMSE_loss.backward()
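One caveat worth noting (not from the original answer): the gradient of torch.sqrt is unbounded as its argument approaches zero, so a small epsilon is sometimes added for numerical stability. The eps value below is an arbitrary choice:

eps = 1e-8  # arbitrary small constant; keeps the gradient finite when the error is near 0
RMSE_loss = torch.sqrt(loss_fn(prediction, target) + eps)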

Hope that helps.

answered Nov 03 '22 by KsEuro

To replicate PyTorch's default MSE (mean squared error) loss, change your loss_function method to the following:

def loss_function(predicted_x, target):
    # Mean of the squared errors over the outputs (dim 1)
    loss = torch.sum(torch.square(predicted_x - target), dim=1) / predicted_x.size()[1]
    # Average over the batch, matching nn.MSELoss(reduction='mean')
    loss = torch.sum(loss) / loss.shape[0]
    return loss

Here is why the above method works: MSE stands for mean squared error, so there is no square root (torch.sqrt) to implement. By default, PyTorch's loss averages over all examples in the batch, hence the second averaging step in the method.
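As a sanity check, this sketch (with made-up shapes) shows the rewritten loss_function agreeing with nn.MSELoss() up to floating-point error:

import torch
import torch.nn as nn

predicted_x = torch.randn(8, 20)   # batch of 8, 20 outputs each
target = torch.randn(8, 20)

manual = loss_function(predicted_x, target)
builtin = nn.MSELoss()(predicted_x, target)
assert torch.allclose(manual, builtin)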

To implement RMSELoss and integrate it into your training, you can do it like this:

import torch
import torch.nn as nn

class RMSELoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()  # reuse the built-in MSE

    def forward(self, x, y):
        # Square root of the batch-mean squared error
        return torch.sqrt(self.mse(x, y))

You can use this class like any other loss function in PyTorch.
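For example, here is a minimal sketch of one training step; the linear model, optimizer, and random data are placeholders standing in for the question's CNN setup:

import torch
import torch.nn as nn

model = nn.Linear(10, 20)   # placeholder model with 20 outputs
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
criterion = RMSELoss()

inputs = torch.randn(8, 10)     # batch of 8 placeholder samples
targets = torch.randn(8, 20)    # 20 regression targets each

optimizer.zero_grad()
loss = criterion(model(inputs), targets)
loss.backward()
optimizer.step()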

answered Nov 03 '22 by Sri Raghu Malireddi