I'm trying to train a two-stage model end to end. However, I want to update the different stages with different losses. For example, suppose the end-to-end model is composed of two models, model1 and model2, and the output is computed by running:
features = model1(inputs)
output = model2(features)
I want to update the parameters of model1 with loss1 while keeping the parameters of model2 unchanged. Next, I want to update the parameters of model2 with loss2 while keeping the parameters of model1 unchanged. My full implementation is something like:
import torch
import torch.nn as nn
# Define the first model
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Linear(20, 10)
        self.conv2 = nn.Linear(10, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
# Define the second model
class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.conv1 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.conv1(x)
        return x
# Initialize models
model1 = Net()
model2 = Net1()
# Initialize separate optimizers for each model
optimizer = torch.optim.SGD(model1.parameters(), lr=0.1)
optimizer1 = torch.optim.SGD(model2.parameters(), lr=0.1)
optimizer.zero_grad()
optimizer1.zero_grad()
criterion = nn.CrossEntropyLoss()
# Sample inputs and labels
inputs = torch.randn(2, 20)
labels = torch.randn(2,1)
features = model1(inputs)
outputs_model = model2(features)
loss1 = criterion(outputs_model[0], labels[0])
loss2 = criterion(outputs_model, labels)
loss1.backward(retain_graph=True)
optimizer.step()
optimizer.zero_grad()
optimizer1.zero_grad()
loss2.backward()
However, this raises the following error:
Traceback (most recent call last):
File , line 55, in <module>
loss2.backward()
^^^^^^^^^^^^^^^^
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/_tensor.py", line 521, in backward
torch.autograd.backward(
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/autograd/__init__.py", line 289, in backward
_engine_run_backward(
File "/opt/homebrew/anaconda3/lib/python3.11/site-packages/torch/autograd/graph.py", line 769, in _engine_run_backward
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [10, 5]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I kinda understand why this is happening, but is there a way to address this?
I have a solution that requires an extra forward pass through model2. If someone can think of a solution that doesn't require this, feel free to chime in.
First, a bit about backprop. When you call backward on a tensor, PyTorch computes gradients for every tensor in the computational graph that requires grad. Since loss1 and loss2 are both computed from the output of model2, calling backward on either one computes gradients for all parameters in both model1 and model2.
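Here is a minimal sketch of that behaviour. It uses small stand-in nn.Linear layers in place of the question's Net and Net1 (the shapes are an assumption). After a single backward call on a loss computed from model2's output, every parameter in both models has a populated .grad:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Stand-ins for the two stages (assumed shapes, not the question's exact classes)
model1 = nn.Linear(20, 5)
model2 = nn.Linear(5, 1)

inputs = torch.randn(2, 20)
loss = model2(model1(inputs)).pow(2).mean()

# One backward call walks the whole graph: through model2 and on into model1
loss.backward()

grads_populated = {
    name: p.grad is not None
    for name, p in [*model1.named_parameters(prefix="model1"),
                    *model2.named_parameters(prefix="model2")]
}
print(grads_populated)
```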
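To compute these gradients, PyTorch uses the tensors it saved during the forward pass. These include activations and, for layers like nn.Linear, the parameters themselves. The saved parameters are the same tensor objects as the model's parameters, so they must not be modified in place between the forward and the backward pass. Autograd tracks this with a version counter on each tensor and raises an error if the version has changed since the tensor was saved.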
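You can reproduce the check in isolation. This sketch peeks at the tensor's _version counter, which is a private attribute and is used here only for illustration. An in-place update done under torch.no_grad() is what an optimizer does. It bumps the counter of a tensor that autograd saved for backward, and the later backward call fails with the same kind of RuntimeError as in the question:

```python
import torch

w = torch.randn(3, requires_grad=True)
y = (w * w).sum()            # the multiplication saves w for the backward pass
version_at_forward = w._version

with torch.no_grad():
    w.add_(1.0)              # in-place update, similar to an optimizer step

version_now = w._version     # incremented by the in-place op

try:
    y.backward()             # autograd sees that w changed since it was saved
    error = None
except RuntimeError as e:
    error = str(e)

print(version_at_forward, version_now)
print(error)
```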
When you call optimizer.step(), all parameters in that optimizer are updated in place. This is why you get the error you see:
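This sketch checks that optimizer.step() really modifies parameters in place. It again reads the private _version counter for illustration. The parameter stays the same tensor object with the same data_ptr(), but its version goes up, and that is the change autograd later complains about:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model = nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

model(torch.randn(8, 4)).pow(2).mean().backward()

weight = model.weight
ptr_before, version_before = weight.data_ptr(), weight._version
opt.step()  # updates the parameter tensor in place
ptr_after, version_after = weight.data_ptr(), weight._version

print(ptr_before == ptr_after, version_before, version_after)
```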
loss1.backward(retain_graph=True) # computes gradients
optimizer.step() # updates parameters in place
...
loss2.backward() # throws an error due to the in-place update from optimizer step
Because of this, the step calls must happen after all the backward calls. That raises the question of how to control which gradients end up in which model. The following does not work:
# trying to zero grads fails
loss1.backward(retain_graph=True)  # backward first loss (keep the shared graph for loss2)
optimizer1.zero_grad()  # zero model2 grads
loss2.backward()  # loss2 grads are also accumulated into model1
# changing the loss order results in the same problem, just for model2
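Here is a small check of why this fails. It uses stand-in nn.Linear layers and MSE losses (the shapes are an assumption). Model1's gradient after both backward calls differs from its gradient after loss1 alone. This happens because loss2 accumulated into it, even though model2's gradients were cleared in between:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
model1 = nn.Linear(20, 5)  # stand-in for the first stage
model2 = nn.Linear(5, 1)   # stand-in for the second stage
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
criterion = nn.MSELoss()

inputs, labels = torch.randn(2, 20), torch.randn(2, 1)
out = model2(model1(inputs))
loss1 = criterion(out[0], labels[0])
loss2 = criterion(out, labels)

loss1.backward(retain_graph=True)
grad_from_loss1 = model1.weight.grad.clone()

opt2.zero_grad()   # clears model2's grads only
loss2.backward()   # still flows back into model1 and accumulates there

leaked = not torch.allclose(model1.weight.grad, grad_from_loss1)
print(leaked)
```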
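I think there are two ways around this:
1. A workaround involving step (hacky, probably breaks a bunch of stuff)
2. Run a second forward pass through model2 on detached features, so that loss2 only produces gradients for model2

I opted for the second option: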
import torch
import torch.nn as nn
torch.manual_seed(42)
class Net1(nn.Module):
    def __init__(self):
        super(Net1, self).__init__()
        self.conv1 = nn.Linear(20, 10)
        self.conv2 = nn.Linear(10, 5)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        return x
# Define the second model
class Net2(nn.Module):
    def __init__(self):
        super(Net2, self).__init__()
        self.conv1 = nn.Linear(5, 1)

    def forward(self, x):
        x = self.conv1(x)
        return x
# Initialize models
model1 = Net1()
model2 = Net2()
# Initialize separate optimizers for each model
opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
# snapshot the first weight tensor of each model before any update
p1_1 = next(model1.parameters()).detach().clone()
p2_1 = next(model2.parameters()).detach().clone()
criterion = nn.MSELoss()
inputs = torch.randn(2, 20)
labels = torch.randn(2, 1)
features = model1(inputs)
out1 = model2(features)
out2 = model2(features.detach())  # detach cuts the graph, so loss2 cannot reach model1
opt1.zero_grad()
opt2.zero_grad()
# update only model1 with loss1
loss1 = criterion(out1[0], labels[0])
loss1.backward()
opt2.zero_grad()  # discard the gradients loss1 produced for model2
opt1.step()
# check parameters after update
p1_2 = next(model1.parameters()).detach().clone()
p2_2 = next(model2.parameters()).detach().clone()
assert not (p1_1 == p1_2).any()  # every entry of model1's first weight changed
assert (p2_1 == p2_2).all()  # model2's first weight unchanged
# update only model2 with loss2
loss2 = criterion(out2, labels)
loss2.backward()
opt2.step()
# check parameters after update
p1_3 = next(model1.parameters()).detach().clone()
p2_3 = next(model2.parameters()).detach().clone()
assert (p1_2 == p1_3).all()  # model1's first weight unchanged
assert not (p2_2 == p2_3).any()  # every entry of model2's first weight changed
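To avoid the second forward pass, one possible alternative (not the approach used above) is the inputs argument of Tensor.backward, available in recent PyTorch versions. It restricts which leaf tensors receive gradients. The sketch uses stand-in nn.Sequential/nn.Linear models with the same layer shapes as Net1/Net2, and the helper function names are hypothetical. It routes loss1 only into model1 and loss2 only into model2, runs both backward calls before either step, and checks that the result matches the detach approach:

```python
import copy
import torch
import torch.nn as nn

torch.manual_seed(42)
base1 = nn.Sequential(nn.Linear(20, 10), nn.Linear(10, 5))  # stand-in for Net1
base2 = nn.Linear(5, 1)                                      # stand-in for Net2
criterion = nn.MSELoss()
inputs, labels = torch.randn(2, 20), torch.randn(2, 1)


def train_step_detach(model1, model2):
    """The approach above: a second forward pass of model2 on detached features."""
    opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
    features = model1(inputs)
    out1, out2 = model2(features), model2(features.detach())
    criterion(out1[0], labels[0]).backward()
    opt2.zero_grad()  # discard loss1's gradients on model2
    opt1.step()
    criterion(out2, labels).backward()
    opt2.step()


def train_step_inputs(model1, model2):
    """Alternative: one forward pass, gradients routed with backward(inputs=...)."""
    opt1 = torch.optim.SGD(model1.parameters(), lr=0.1)
    opt2 = torch.optim.SGD(model2.parameters(), lr=0.1)
    out = model2(model1(inputs))
    loss1 = criterion(out[0], labels[0])
    loss2 = criterion(out, labels)
    # Both backward calls happen before any in-place step
    loss1.backward(inputs=list(model1.parameters()), retain_graph=True)
    loss2.backward(inputs=list(model2.parameters()))
    opt1.step()
    opt2.step()


# Run both approaches on identical copies of the same initial weights
a1, a2 = copy.deepcopy(base1), copy.deepcopy(base2)
b1, b2 = copy.deepcopy(base1), copy.deepcopy(base2)
train_step_detach(a1, a2)
train_step_inputs(b1, b2)

same = all(
    torch.allclose(p, q)
    for p, q in zip([*a1.parameters(), *a2.parameters()],
                    [*b1.parameters(), *b2.parameters()])
)
print(same)
```

The two give the same parameters because out2 in the detach approach is computed from features produced before model1 is updated, so it has exactly the same values as out.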
Note that I changed your loss from cross entropy to MSE. Your model produces an output of shape (bs, 1), while cross entropy expects an output of shape (bs, num_classes). An output of shape (bs, 1) means you are computing cross entropy over a single class, which always returns a loss of 0 (with zero gradients):
pred = torch.randn(64, 1)
labels = torch.randn(64, 1)
criterion = nn.CrossEntropyLoss()
criterion(pred, labels)
> tensor(-0.)
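Here is why the loss is stuck at zero. With a single column, log_softmax over the class dimension is identically 0, so both the loss and its gradient are zero and nothing would train. The sketch checks that, then shows the input shape CrossEntropyLoss does expect. It assumes a hypothetical 3-class problem with integer class labels:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
criterion = nn.CrossEntropyLoss()

# Single "class": log_softmax over dim 1 is always 0, so loss and grad vanish
pred = torch.randn(64, 1, requires_grad=True)
labels = torch.randn(64, 1)
loss = criterion(pred, labels)
loss.backward()
print(loss.item(), pred.grad.abs().max().item())

# Expected usage: logits of shape (bs, num_classes), class-index targets of shape (bs,)
num_classes = 3  # hypothetical
logits = torch.randn(64, num_classes, requires_grad=True)
targets = torch.randint(0, num_classes, (64,))
ce = criterion(logits, targets)
ce.backward()
print(ce.item(), logits.grad.abs().max().item())
```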