
Training different stages of a model with different losses


I'm trying to train a two-stage model end to end. However, I want to update the different stages of the model with different losses. For example, suppose the end-to-end model is composed of two models, model1 and model2. The output is computed by running

features = model1(inputs)
output = model2(features)

I want to update the parameters of model1 with loss1 while keeping the parameters of model2 unchanged. Next, I want to update the parameters of model2 with loss2 while keeping the parameters of model1 unchanged. My full implementation is something like:

import torch
import torch.nn as nn

# Define the first model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Linear(20, 10)
        self.conv2 = nn.Linear(10, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x

# Define the second model
class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.conv1 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.conv1(x)
        return x

# Initialize models
model1 = Net()
model2 = Net1()

# Initialize separate optimizers for each model
optimizer = torch.optim.SGD(model1.parameters(), lr=0.1)
optimizer1 = torch.optim.SGD(model2.parameters(), lr=0.1)

optimizer.zero_grad() 
optimizer1.zero_grad()

criterion = nn.CrossEntropyLoss()

# Sample inputs and labels
inputs = torch.randn(2, 20)
labels = torch.randn(2,1)

features = model1(inputs)         
outputs_model = model2(features) 

loss1 = criterion(outputs_model[0], labels[0]) 
loss2 = criterion(outputs_model, labels) 
   
loss1.backward(retain_graph=True)  
optimizer.step()  
optimizer.zero_grad()       
optimizer1.zero_grad()  

 
loss2.backward()        

However, this raises the following error:

Traceback (most recent call last):
  File , line 55, in <module>
    loss2.backward()        
    ^^^^^^^^^^^^^^^^
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/_tensor.py", line 521, in backward
    torch.autograd.backward(
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py", line 289, in backward
    _engine_run_backward(
  File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 5]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I kinda understand why this is happening, but is there a way to address this?


Solution

  • I have a solution that requires an extra forward pass through model2. If someone can think of a solution that avoids this, feel free to chime in.

    First, a bit about backprop. When you call backward() on a tensor, PyTorch computes gradients for every leaf tensor in its computational graph that requires grad, such as model parameters. Since loss1 and loss2 are both computed from the output of model2, which in turn depends on model1, calling backward() on either one computes gradients for all parameters of both model1 and model2.

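    To compute these gradients, PyTorch uses tensors saved during the forward pass. These include activations and also parameters (for example, a linear layer saves its weight so it can compute the gradient with respect to its input). Autograd records a version counter for each saved tensor and raises an error if that tensor was modified in place before backward() uses it. This means that the parameters used in the backward computation must be unchanged since the forward pass.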
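    A minimal sketch of that version check, independent of your models: w * w saves w for the backward pass, and an in-place update made afterwards (under torch.no_grad(), as optimizers do) makes the later backward() fail with the same kind of error. The _version attribute is internal and underscore-prefixed; it is read here only for illustration.

```python
import torch


def reproduce_inplace_error():
    """Return (versions, message) from a backward pass over a stale saved tensor."""
    w = torch.randn(3, requires_grad=True)
    y = (w * w).sum()              # MulBackward0 saves w for the backward pass
    versions = [w._version]        # version counter right after the forward pass
    with torch.no_grad():          # optimizers update parameters under no_grad
        w.add_(1.0)                # in-place update increments w's version counter
    versions.append(w._version)
    try:
        y.backward()               # autograd detects that the saved w has changed
    except RuntimeError as err:
        return versions, str(err)
    return versions, "no error"


versions, message = reproduce_inplace_error()
print("versions:", versions)
print(message)
```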

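    When you call optimizer.step(), every parameter managed by that optimizer is updated in place, which increments its version counter. In your code, optimizer.step() modifies model1's weights, but the retained graph still needs their old values when loss2.backward() propagates back through model1 (the [10, 5] tensor in the error is most likely the transposed weight of model1.conv2). This is why you get the error.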

    loss1.backward(retain_graph=True) # computes gradients
    optimizer.step() # updates parameters in place
    ...
     
    loss2.backward() # throws an error due to the in-place update from optimizer step
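    To see that optimizer.step() really is an in-place update, here is a small sketch on a stand-alone nn.Linear (not your models): the parameter remains the same Python object after the step, but its internal _version counter goes up, which is exactly what autograd's check looks at.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(4, 2)
opt = torch.optim.SGD(layer.parameters(), lr=0.1)

weight_obj = layer.weight                     # keep a reference to the Parameter
version_before = layer.weight._version        # internal in-place modification counter
value_before = layer.weight.detach().clone()

layer(torch.randn(3, 4)).sum().backward()     # only fills layer.weight.grad
opt.step()                                    # plain SGD: weight -= lr * grad, in place

same_object = layer.weight is weight_obj
version_bumped = layer.weight._version > version_before
value_changed = not torch.equal(layer.weight.detach(), value_before)
print(same_object, version_bumped, value_changed)
```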
    

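    Because of this, all step() calls must happen after all backward() calls. That raises the problem of keeping each loss's gradients confined to its own model, and zeroing gradients between the backward calls does not achieve that. The following does not work: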

    # trying to zero grads between the backward calls fails
    loss1.backward(retain_graph=True) # loss1 grads for model1 and model2
    optimizer1.zero_grad() # zero model2 grads
    loss2.backward() # loss2 grads are still added to model1's grads
    
    # swapping the order of the losses has the same problem for model2
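    A runnable sketch of why that fails, using small stand-in models (single nn.Linear layers and MSELoss, not your exact setup): torch.autograd.grad gives the gradients of loss1 and loss2 with respect to model1's weight without touching .grad, and after the zero-grad-between-backwards sequence, model1's .grad equals their sum rather than loss1's gradient alone.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model1 = nn.Linear(20, 5)   # stand-in for the first stage
model2 = nn.Linear(5, 1)    # stand-in for the second stage
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs, labels = torch.randn(2, 20), torch.randn(2, 1)
out = model2(model1(inputs))
loss1 = criterion(out[0], labels[0])
loss2 = criterion(out, labels)

# Reference gradients for model1's weight; autograd.grad does not write to .grad
g1 = torch.autograd.grad(loss1, model1.weight, retain_graph=True)[0]
g2 = torch.autograd.grad(loss2, model1.weight, retain_graph=True)[0]

opt1.zero_grad()
opt2.zero_grad()
loss1.backward(retain_graph=True)
opt2.zero_grad()    # clears model2's loss1 grads...
loss2.backward()    # ...but loss2 also accumulates into model1's grads

contaminated = torch.allclose(model1.weight.grad, g1 + g2)
loss1_only = torch.allclose(model1.weight.grad, g1)
print("model1 grad == loss1 + loss2:", contaminated, "| loss1 only:", loss1_only)
```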
    

    I think there are two ways around this:

    1. Cache the gradients outside the PyTorch parameters and manually re-add them to the parameters before calling step() (hacky, probably breaks a bunch of stuff)
    2. Do an extra forward pass (costs extra compute, but works within PyTorch's normal structure)
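    For what it's worth, option 1 does not require touching PyTorch internals if the gradients are computed with torch.autograd.grad, which returns gradients instead of accumulating them into .grad. The sketch below is an alternative to the approach used in this answer, with stand-in nn.Sequential/nn.Linear models and MSELoss: it does a single forward pass, computes each loss's gradients only for the model that loss should update, and calls step() only after all backward work is done.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model1 = nn.Sequential(nn.Linear(20, 10), nn.Linear(10, 5))  # stand-in first stage
model2 = nn.Linear(5, 1)                                      # stand-in second stage
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs, labels = torch.randn(2, 20), torch.randn(2, 1)
out = model2(model1(inputs))          # a single forward pass
loss1 = criterion(out[0], labels[0])
loss2 = criterion(out, labels)

params1 = list(model1.parameters())
params2 = list(model2.parameters())

# Each loss only produces gradients for the model it is meant to update.
grads1 = torch.autograd.grad(loss1, params1, retain_graph=True)
grads2 = torch.autograd.grad(loss2, params2)

before1 = [p.detach().clone() for p in params1]
before2 = [p.detach().clone() for p in params2]

# All backward work is finished, so in-place updates can no longer break the graph.
for p, g in zip(params1, grads1):
    p.grad = g
for p, g in zip(params2, grads2):
    p.grad = g
opt1.step()
opt2.step()
```

    Calling loss.backward(inputs=...), available since PyTorch 1.8, is a similar option: it accumulates gradients only into the listed tensors, so loss1.backward(inputs=params1, retain_graph=True) followed by loss2.backward(inputs=params2) and then both step() calls should also avoid the extra forward pass.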

    I opted for the second option:

    import torch
    import torch.nn as nn
    
    torch.manual_seed(42)
    
    class Net1(nn.Module):
        def __init__(self):
            super(Net1, self).__init__()
            self.conv1 = nn.Linear(20, 10)
            self.conv2 = nn.Linear(10, 5)
    
        def forward(self, x):
            x = self.conv1(x)
            x = self.conv2(x)
            return x
    
    # Define the second model
    class Net2(nn.Module):
        def __init__(self):
            super(Net2, self).__init__()
            self.conv1 = nn.Linear(5, 1)
    
        def forward(self, x):
            x = self.conv1(x)
            return x
    
    # Initialize models
    model1 = Net1()
    model2 = Net2()
    
    # Initialize separate optimizers for each model
    opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
    
    # snapshot the first weight tensor of each model before any update
    p1_1 = next(model1.parameters()).detach().clone()
    p2_1 = next(model2.parameters()).detach().clone()
    
    criterion = nn.MSELoss()
    
    inputs = torch.randn(2, 20)
    labels = torch.randn(2,1)
    
    features = model1(inputs)
    out1 = model2(features)
    out2 = model2(features.detach()) # detach() cuts the graph here, so loss2 cannot reach model1
    
    opt1.zero_grad()
    opt2.zero_grad()
    
    # update only model1 with loss1
    loss1 = criterion(out1[0], labels[0]) 
    loss1.backward()
    opt2.zero_grad() # discard the grads loss1 produced for model2
    opt1.step()
    
    # check parameters after update
    p1_2 = next(model1.parameters()).detach().clone()
    p2_2 = next(model2.parameters()).detach().clone()
    assert not (p1_1 == p1_2).any() # every entry of model1's first weight changed
    assert (p2_1 == p2_2).all() # model2's first weight is unchanged
    
    # update only model2 with loss2
    loss2 = criterion(out2, labels)
    loss2.backward()
    opt2.step()
    
    # check parameters after update
    p1_3 = next(model1.parameters()).detach().clone()
    p2_3 = next(model2.parameters()).detach().clone()
    assert (p1_2 == p1_3).all() # model1's first weight is unchanged
    assert not (p2_2 == p2_3).any() # every entry of model2's first weight changed
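    The same pattern can be wrapped in a per-iteration helper for a training loop. This is a sketch: the stand-in models, optimizers and criterion are recreated here so the block runs on its own, and the helper name train_step is hypothetical.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model1 = nn.Sequential(nn.Linear(20, 10), nn.Linear(10, 5))  # stand-in first stage
model2 = nn.Linear(5, 1)                                      # stand-in second stage
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
criterion = nn.MSELoss()


def train_step(inputs, labels):
    """One iteration: loss1 updates only model1, loss2 updates only model2."""
    features = model1(inputs)
    out1 = model2(features)            # graph through both models, used by loss1
    out2 = model2(features.detach())   # extra forward pass, graph through model2 only

    opt1.zero_grad()
    opt2.zero_grad()
    loss1 = criterion(out1[0], labels[0])
    loss1.backward()
    opt2.zero_grad()                   # drop the grads loss1 put on model2
    opt1.step()                        # safe: out2's graph does not use model1's weights

    loss2 = criterion(out2, labels)
    loss2.backward()                   # only reaches model2
    opt2.step()
    return loss1.item(), loss2.item()


inputs, labels = torch.randn(2, 20), torch.randn(2, 1)
history = [train_step(inputs, labels) for _ in range(5)]
print(history)
```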
    

    Note that I changed your loss from cross entropy to MSE. Your model produces an output of shape (bs, 1), while cross entropy expects an output of shape (bs, num_classes). An output of shape (bs, 1) means you are computing cross entropy over a single class, which always gives a loss of 0 (and zero gradients, so nothing gets trained):

    pred = torch.randn(64, 1)
    labels = torch.randn(64, 1)
    criterion = nn.CrossEntropyLoss()
    criterion(pred, labels)
    > tensor(-0.)
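    A quick sketch confirming that the gradient is zero as well, so a single-output model trained with CrossEntropyLoss never changes. If your labels are actually binary (0/1), nn.BCEWithLogitsLoss accepts (bs, 1) logits directly; the random binary labels below are an assumption made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
pred = torch.randn(64, 1, requires_grad=True)   # (bs, 1) output, as in the question
labels = torch.randn(64, 1)

ce = nn.CrossEntropyLoss()(pred, labels)        # softmax over a single class is always 1
ce.backward()
ce_value = ce.item()
ce_grad_is_zero = bool((pred.grad == 0).all())
print(ce_value, ce_grad_is_zero)

# Hypothetical binary labels: BCEWithLogitsLoss takes one logit per sample
binary_labels = torch.randint(0, 2, (64, 1)).float()
bce = nn.BCEWithLogitsLoss()(pred.detach(), binary_labels)
print(bce.item())
```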