I am currently facing a problem.
I am wondering how I should backpropagate the loss functions in the following model.
What is important here is that the entire blue part is shared by the two outputs, the green part is a binary classification head trained with BCELoss, and the red part is a regression head trained with MSELoss.
Here is the code, along with an image of the model, to make the architecture clear:
```python
import torch
import torch.nn as nn

# ConvModel (the graph convolution part) is defined elsewhere in my code


class First_branch(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(First_branch, self).__init__()
        self.fc1 = nn.Linear(input_size * num_heads, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


class Final_branch_1(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(Final_branch_1, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = torch.sigmoid(self.fc3(x))  # Binary classification
        return x


class Final_branch_2(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super(Final_branch_2, self).__init__()
        self.fc1 = nn.Linear(64, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)  # Regression task
        return x


class Model(nn.Module):
    def __init__(self, in_feats_dict, hidden_feats_dict, out_feats_dict, edge_feats_dict, rel_names):
        super().__init__()
        self.conv = ConvModel(in_feats_dict, hidden_feats_dict, out_feats_dict, edge_feats_dict, rel_names, num_heads=3)
        self.embed = First_branch(sum(out_feats_dict.values()), num_heads=3)
        self.pred_1 = Final_branch_1(64, num_heads=3)
        self.pred_2 = Final_branch_2(64, num_heads=3)

    def forward(self, g, node_features, edge_features):
        conv_output = self.conv(g, node_features, edge_features)
        to_concat = [conv_output[key] for key in conv_output.keys()]
        # Aggregate the results following each latent feature
        aggregated_features = [torch.mean(i, dim=0) for i in to_concat]
        concatenated_features = torch.cat(aggregated_features)
        embedded_features = self.embed(concatenated_features)
        prediction_1 = self.pred_1(embedded_features)
        prediction_2 = self.pred_2(embedded_features)
        return prediction_1, prediction_2
```
My question is: what is the best way to compute the loss for these two tasks?
I know that with two MSELoss terms, the usual approach is simply to sum the losses:
```python
loss_1 = criterion(output_1, score_1)
loss_2 = criterion(output_2, score_2)
loss = loss_1 + loss_2
loss.backward()
optimizer.step()
```
But since the first branch does classification, I need to use BCELoss, and summing a BCELoss and an MSELoss object no longer works.
For now, my solution is:
```python
bce = nn.BCELoss()
mse = nn.MSELoss()
loss_1 = bce(output_1, score_1)
loss_2 = mse(output_2, score_2)
loss_1.backward(retain_graph=True)
loss_2.backward()
optimizer.step()
```
Is this the right way? My model seems to have trouble learning when I do it this way.
I'm not sure what you mean when you say that summing a BCELoss and an MSELoss no longer works. You can still sum the losses. Summing them produces the same gradients (up to floating-point rounding) as calling `backward()` on each loss individually:
```python
import torch
import torch.nn as nn

# simple dummy model
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.stem = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 32),
            nn.ReLU()
        )
        self.head1 = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )
        self.head2 = nn.Sequential(
            nn.Linear(32, 32),
            nn.ReLU(),
            nn.Linear(32, 1),
        )

    def forward(self, x):
        x = self.stem(x)
        p1 = torch.sigmoid(self.head1(x))
        p2 = self.head2(x)
        return p1, p2

# inputs/outputs
x = torch.randn(4, 32)
y_regression = torch.randn(4, 1)
y_classification = (torch.randn(4, 1) > 0).float()
model = Model()
bce = nn.BCELoss()
mse = nn.MSELoss()

# compute predictions, backward combined loss
p1, p2 = model(x)
loss1 = bce(p1, y_classification)
loss2 = mse(p2, y_regression)
loss = loss1 + loss2
loss.backward()
# grab gradient
g1 = model.stem[0].weight.grad.clone()

# zero gradients
for param in model.parameters():
    param.grad = None

# compute predictions, backward separate losses
p1, p2 = model(x)
loss1 = bce(p1, y_classification)
loss1.backward(retain_graph=True)
loss2 = mse(p2, y_regression)
loss2.backward()
# grab gradient
g2 = model.stem[0].weight.grad.clone()

# validate gradients are the same
print(torch.allclose(g1, g2))  # should print True
```
When you call `backward()` multiple times, PyTorch accumulates the gradients from each call, summing them into each parameter's `.grad`.
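A tiny illustration of that accumulation, using a single scalar parameter instead of a model (a sketch, not tied to the code above): two separate `backward()` calls leave the same `.grad` as one call on the summed expression.

```python
import torch

# A single parameter used by two separate "losses"
w = torch.tensor(2.0, requires_grad=True)

(3 * w).backward()   # d(3w)/dw = 3
print(w.grad)        # tensor(3.)
(5 * w).backward()   # d(5w)/dw = 5, added to the existing .grad
print(w.grad)        # tensor(8.)

w.grad = None                # reset, as in the example above
(3 * w + 5 * w).backward()   # one backward on the summed loss
print(w.grad)                # tensor(8.)
```

Each expression builds a fresh graph here, so `retain_graph=True` is not needed; it only becomes necessary when the two losses share part of the graph, as in your model.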
Calling `backward()` on each loss term separately is otherwise equivalent, but it is not cheaper in your case: `retain_graph=True` keeps the shared part of the graph alive until the second call, and the backward pass through the shared layers runs twice. Summing the losses and calling `backward()` once is usually the simpler option. If the model isn't learning well, you may need to add weights to the different loss terms. As written, PyTorch adds the gradients of each loss at their raw scale, so if one loss term is much larger than the other, it will dominate the overall gradient. You can correct this by scaling the larger term to be more in line with the smaller one (or vice versa).
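A minimal sketch of a weighted sum, assuming a toy trunk and two heads in place of your real model. The weights `w_cls` and `w_reg` are hypothetical hyperparameters, not values from this thread; you would tune them, for example until both weighted terms have a similar magnitude early in training.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical shared trunk + two heads, mirroring the structure discussed above
stem = nn.Sequential(nn.Linear(32, 32), nn.ReLU())
head_cls = nn.Linear(32, 1)  # classification head (sigmoid applied below)
head_reg = nn.Linear(32, 1)  # regression head
params = list(stem.parameters()) + list(head_cls.parameters()) + list(head_reg.parameters())
optimizer = torch.optim.Adam(params, lr=1e-3)

bce = nn.BCELoss()
mse = nn.MSELoss()

# Hypothetical loss weights: tune these so neither term dominates
w_cls, w_reg = 1.0, 0.1

x = torch.randn(8, 32)
y_cls = (torch.randn(8, 1) > 0).float()
y_reg = torch.randn(8, 1) * 10  # large-scale targets make the MSE much bigger than the BCE

optimizer.zero_grad()
h = stem(x)
loss_cls = bce(torch.sigmoid(head_cls(h)), y_cls)
loss_reg = mse(head_reg(h), y_reg)
loss = w_cls * loss_cls + w_reg * loss_reg  # weighted sum, single backward
loss.backward()
optimizer.step()

print(f"bce={loss_cls.item():.3f} mse={loss_reg.item():.3f} total={loss.item():.3f}")
```

Since the combined loss is a plain weighted sum, a single `backward()` call is enough and no `retain_graph` is needed.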
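Also, as a minor note: your `First_branch` ends with a linear layer and your `Final_branch_1`/`Final_branch_2` start with a linear layer, which is redundant, because two linear layers with no nonlinearity between them collapse into a single linear map. You'll likely get a slight improvement from ending your `First_branch` with a ReLU.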
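A short check of why the stacked linear layers are redundant, with sizes chosen to match `First_branch.fc2` followed by a `Final_branch` `fc1`. The merged layer below is only an illustration; the `First_branch` at the end is the question's class with one ReLU added.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Two linear layers with no activation in between (like First_branch.fc2 -> Final_branch_1.fc1)
lin_a = nn.Linear(128, 64)
lin_b = nn.Linear(64, 128)

# They are equivalent to one linear layer: W = Wb @ Wa, b = Wb @ ba + bb
merged = nn.Linear(128, 128)
with torch.no_grad():
    merged.weight.copy_(lin_b.weight @ lin_a.weight)
    merged.bias.copy_(lin_b.weight @ lin_a.bias + lin_b.bias)

x = torch.randn(4, 128)
print(torch.allclose(lin_b(lin_a(x)), merged(x), atol=1e-5))  # True

# Suggested change: end First_branch with a ReLU so the stacked layers stay non-linear
class First_branch(nn.Module):
    def __init__(self, input_size, num_heads=3):
        super().__init__()
        self.fc1 = nn.Linear(input_size * num_heads, 128)
        self.fc2 = nn.Linear(128, 64)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return torch.relu(self.fc2(x))  # ReLU added here
```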
You could also consider using a combined head for the outputs: a single MLP outputs a tensor of shape `(bs, 2)`, where the first column is routed to the classification loss and the second to the regression loss. In theory this lets the two output types share information; you'll have to experiment to see whether it makes a difference.
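A rough sketch of such a combined head, assuming the 64-dimensional shared embedding from the question. `CombinedHead` and its sizes are hypothetical; the sigmoid on column 0 keeps it compatible with `BCELoss`, and column 1 stays raw for `MSELoss`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

class CombinedHead(nn.Module):
    """Hypothetical single MLP head producing both outputs as columns of one tensor."""
    def __init__(self, in_features=64, hidden=128):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(in_features, hidden),
            nn.ReLU(),
            nn.Linear(hidden, 2),  # column 0: classification logit, column 1: regression value
        )

    def forward(self, x):
        out = self.mlp(x)                   # shape (bs, 2)
        p_cls = torch.sigmoid(out[:, 0:1])  # (bs, 1) probability for BCELoss
        p_reg = out[:, 1:2]                 # (bs, 1) raw value for MSELoss
        return p_cls, p_reg

head = CombinedHead()
bce, mse = nn.BCELoss(), nn.MSELoss()

emb = torch.randn(4, 64)  # stand-in for the shared embedding
y_cls = (torch.randn(4, 1) > 0).float()
y_reg = torch.randn(4, 1)

p_cls, p_reg = head(emb)
loss = bce(p_cls, y_cls) + mse(p_reg, y_reg)
loss.backward()
print(p_cls.shape, p_reg.shape)  # torch.Size([4, 1]) torch.Size([4, 1])
```

Slicing with `0:1` and `1:2` (rather than `0` and `1`) keeps the trailing dimension, so the predictions match the `(bs, 1)` target shapes the losses expect.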