
Multi loss going into the same subsequent model using PyTorch


I am currently facing a problem.

I am wondering how I should backpropagate the losses in the following model.

The important point is that the whole blue part is shared by both outputs, the green part is a binary classification head trained with BCELoss, and the red part is a regression head trained with MSELoss.

Here is the code, together with an image of the model (Model visualization), to make the architecture clear:

import torch
import torch.nn as nn

class First_branch(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(First_branch, self).__init__()
        self.fc1 = nn.Linear(input_size*num_heads, 128)
        self.fc2 = nn.Linear(128, 64)
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

class Final_branch_1(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(Final_branch_1, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x)) # Binary classification
        return x

class Final_branch_2(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(Final_branch_2, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # Regression task
        return x

class Model(nn.Module):
    def __init__(self, in_feats_dict, hidden_feats_dict, out_feats_dict, edge_feats_dict, rel_names):
        super().__init__()
        # ConvModel is the asker's graph-convolution module (its definition is not shown)
        self.conv = ConvModel(in_feats_dict, hidden_feats_dict, out_feats_dict, edge_feats_dict, rel_names, num_heads=3)
        self.embed = First_branch(sum(out_feats_dict.values()), num_heads=3)
    
        self.pred_1 = Final_branch_1(64, num_heads=3)
        self.pred_2 = Final_branch_2(64, num_heads=3)

    def forward(self, g, node_features, edge_features):
        conv_output = self.conv(g, node_features, edge_features)
    
        to_concat = [conv_output[key] for key in conv_output.keys()]
    
        # Mean-pool each output tensor over dim 0 to get one feature vector per key
        aggregated_features = [torch.mean(i, dim=0) for i in to_concat]           
        
        concatenated_features = torch.cat(aggregated_features)
    
        embedded_features = self.embed(concatenated_features)
    
        prediction_1 = self.pred_1(embedded_features)
        prediction_2 = self.pred_2(embedded_features)
    
        return prediction_1, prediction_2

The question is: what is the best way to compute the loss for these two tasks?

I know that with two MSELoss terms, the usual approach is simply to sum the losses:

loss_1 = criterion(output_1, score_1)
loss_2 = criterion(output_2, score_2)
loss = loss_1 + loss_2
loss.backward()
optimizer.step()

But since the first branch does classification, I need to use BCELoss, and summing a BCELoss and an MSELoss no longer works.

For the moment, my solution is :

bce = nn.BCELoss()
mse = nn.MSELoss()
loss_1 = bce(output_1, score_1)
loss_2 = mse(output_2, score_2)
loss_1.backward(retain_graph=True)
loss_2.backward()
optimizer.step()

But is this the right way? My model seems to have trouble learning when I do it like this.


Solution

  • I'm not sure what you mean by summing a BCELoss and an MSELoss no longer working. You can still sum the losses: with the default reduction='mean', both return a scalar (0-dim) tensor. Summing the losses and calling backward once gives the same gradients as calling backward on each loss individually:

    import torch
    import torch.nn as nn
    
    # simple dummy model
    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            
            self.stem = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 32),
                nn.ReLU()
            )
            
            self.head1 = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
            )
            
            self.head2 = nn.Sequential(
                nn.Linear(32, 32),
                nn.ReLU(),
                nn.Linear(32, 1),
            )
            
        def forward(self, x):
            x = self.stem(x)
            p1 = torch.sigmoid(self.head1(x))
            p2 = self.head2(x)
            
            return p1, p2
    
    
    # inputs/outputs
    x = torch.randn(4, 32)
    y_regression = torch.randn(4, 1)
    y_classification = (torch.randn(4, 1) > 0).float()
        
    model = Model()
    
    bce = nn.BCELoss()
    mse = nn.MSELoss()
    
    
    # compute predictions, backward combined loss
    p1, p2 = model(x)
    loss1 = bce(p1, y_classification)
    loss2 = mse(p2, y_regression)
    loss = loss1 + loss2
    loss.backward()
    
    # grab gradient
    g1 = model.stem[0].weight.grad.clone()
    
    # zero gradients
    for param in model.parameters():
        param.grad = None
    
    # compute predictions, backward separate losses
    p1, p2 = model(x)
    loss1 = bce(p1, y_classification)
    loss1.backward(retain_graph=True)
    loss2 = mse(p2, y_regression)
    loss2.backward()
    
    # grab gradient
    g2 = model.stem[0].weight.grad.clone()
    
    # validate gradients are the same (up to floating-point rounding)
    print(torch.allclose(g1, g2))
    

    When you call backward multiple times, PyTorch accumulates the gradients of the individual terms by summing them into each parameter's .grad.
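A minimal sketch of that accumulation on a single toy tensor (the tensor `w`, the helper `compute_losses`, and both toy objectives are illustrative, not from the question). The shared intermediate `h` plays the role of the common stem, which is why the first separate `backward` call needs `retain_graph=True`:

```python
import torch

# Toy "parameter" shared by both loss terms (illustrative only)
w = torch.tensor([1.0, -2.0, 3.0], requires_grad=True)

def compute_losses(w):
    h = 2 * w                # shared intermediate, like the common stem
    loss_a = (h ** 2).sum()  # first objective
    loss_b = h.sum()         # second objective
    return loss_a, loss_b

# Option 1: sum the losses and call backward once
loss_a, loss_b = compute_losses(w)
(loss_a + loss_b).backward()
grad_summed = w.grad.clone()

w.grad = None  # clear the accumulated gradient before the second run

# Option 2: call backward on each loss; PyTorch adds both results into w.grad
loss_a, loss_b = compute_losses(w)
loss_a.backward(retain_graph=True)  # keep the shared graph alive for loss_b
loss_b.backward()
grad_separate = w.grad.clone()

print(torch.equal(grad_summed, grad_separate))  # True
print(grad_summed)
```

Both runs leave the same values in `w.grad`; the second simply reaches them by adding two backward passes.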

    Calling backward on the individual loss terms gives the same result, but because the two losses share the stem it needs retain_graph=True and backpropagates through the shared layers twice, so summing the losses and calling backward once is usually simpler. If the model isn't learning well, you may need to weight the different loss terms. As-is, the gradients from each loss are summed according to the losses' raw numerical values: if one loss term is much larger than the other (for example, an MSE on targets with a large numeric range), it will dominate the overall gradient. This can be corrected by scaling the larger term down to be more in line with the smaller one (or vice versa).
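A sketch of the weighting idea, assuming fixed hand-picked weights. `w_cls` and `w_reg` are hypothetical hyperparameters you would tune, for example so that both weighted terms start at a similar magnitude; the dummy tensors stand in for the two heads' outputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bce = nn.BCELoss()
mse = nn.MSELoss()

# Hypothetical weights: tune them so neither term dominates the shared gradient
w_cls, w_reg = 1.0, 0.1

# Stand-ins for the two heads' raw outputs (leaf tensors so .grad can be inspected)
logits = torch.randn(4, 1, requires_grad=True)
reg_out = torch.randn(4, 1, requires_grad=True)

y_cls = (torch.randn(4, 1) > 0).float()
y_reg = 10 * torch.randn(4, 1)  # regression targets on a larger scale

loss_cls = bce(torch.sigmoid(logits), y_cls)
loss_reg = mse(reg_out, y_reg)

# Weighted sum: a single scalar and a single backward call
loss = w_cls * loss_cls + w_reg * loss_reg
loss.backward()

print(f"cls: {loss_cls.item():.3f}  reg: {loss_reg.item():.3f}  total: {loss.item():.3f}")
```

In the question's training loop, this would replace the two separate `backward` calls with one `loss.backward()` on the weighted sum.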

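    Also, as a minor note: your First_branch ends with a linear layer, and Final_branch_1 and Final_branch_2 each start with a linear layer. Two linear layers with no nonlinearity between them compose into a single linear map, so the extra layer adds parameters without adding expressiveness. Ending First_branch with a ReLU may give you a slight improvement.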
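A sketch of that change applied to the question's `First_branch`. Layer sizes are unchanged and only the final activation is new; the input size of 10 in the usage lines is an arbitrary example value:

```python
import torch
import torch.nn as nn

class First_branch(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super().__init__()
        self.fc1 = nn.Linear(input_size * num_heads, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        # ReLU here so the first Linear of Final_branch_1/2 does not directly
        # follow another Linear (two stacked Linears collapse into one map)
        x = torch.relu(self.fc2(x))
        return x

embed = First_branch(input_size=10)  # 10 is an arbitrary example size
out = embed(torch.randn(30))         # 1-D input, as in the question's forward
print(out.shape)  # torch.Size([64])
```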

    You should also consider using a combined head for the outputs. You can have a single MLP output a tensor of shape (batch_size, 2), where the first column is routed to the classification loss and the second to the regression loss. In theory this allows the two output types to share information; you'll have to experiment to see whether it makes a difference.
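A minimal sketch of the combined-head idea, assuming batched input of shape (batch_size, 64), i.e. the embedding size the question's `First_branch` produces. The class name `JointHead` and the hidden size are illustrative choices. Column 0 is squashed with a sigmoid for `nn.BCELoss`, and column 1 is left unactivated for `nn.MSELoss`:

```python
import torch
import torch.nn as nn

class JointHead(nn.Module):
    """Hypothetical single MLP emitting both outputs as a (batch_size, 2) tensor."""

    def __init__(self, in_features=64, hidden=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),
        )

    def forward(self, x):
        out = self.mlp(x)                   # shape (batch_size, 2)
        p_cls = torch.sigmoid(out[:, 0:1])  # column 0 -> classification probability
        y_reg = out[:, 1:2]                 # column 1 -> raw regression value
        return p_cls, y_reg

head = JointHead()
emb = torch.randn(8, 64)                    # stand-in for the embedded features
target_cls = (torch.randn(8, 1) > 0).float()
target_reg = torch.randn(8, 1)

p_cls, y_reg = head(emb)
loss = nn.BCELoss()(p_cls, target_cls) + nn.MSELoss()(y_reg, target_reg)
loss.backward()
```

Slicing with `0:1` and `1:2` keeps each output at shape (batch_size, 1), so it matches the targets without reshaping.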