I have a loss function that depends on an "exponential moving average" Z. A minimal example follows (pay special attention to the getUpdatedZ function):
import torch
import torch.nn as nn


class FeedForward(nn.Module):
    def __init__(self):
        super(FeedForward, self).__init__()
        self.model = nn.Sequential(nn.Linear(1, 100),
                                   nn.ReLU(),
                                   nn.Linear(100, 1))

    def forward(self, x):
        return self.model(x)


model = FeedForward()
nEpochs = 100
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)


def getTrainingPoints():
    return torch.rand(1000, 1)


def lossFunction(X, Z):
    # Returning Z here is enough to expose the problem. The real loss is more complicated.
    return Z


def getUpdatedZ(X, Z):
    U = model(X)
    Znew = torch.mean(U)
    # Having Z in this computation creates an inplace operation (I'm not sure why).
    # Returning, for example, Znew does not cause any issues (but the computation is incorrect).
    return 0.2 * Z + 0.8 * Znew


Z = torch.tensor([1.0])
X = getTrainingPoints()

for i in range(nEpochs):
    optimizer.zero_grad()
    Z = getUpdatedZ(X, Z)
    loss = lossFunction(X, Z)
    # The loss function depends on the gradient of the model in the real version of the code, hence retain_graph=True.
    loss.backward(retain_graph=True)
    optimizer.step()
I get the following error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [100, 1]], which is output 0 of AsStridedBackward0, is at version 2; expected version 1 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
After some trials, I think the error arises because you are computing a recursive function (Z = getUpdatedZ(X, Z)) while changing some of its parameters (the weights of the Linear modules) at each iteration through optimizer.step(). The graph that Z carries over from the previous iteration still references the old versions of those weights, which have since been modified in place.
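To see the in-place modification concretely, here is a minimal, self-contained sketch (separate from your code) that prints a weight's internal version counter before and after optimizer.step(). Note that _version is an internal PyTorch attribute, so treat this purely as a debugging illustration:

import torch
import torch.nn as nn

lin = nn.Linear(1, 1)
opt = torch.optim.SGD(lin.parameters(), lr=0.1)

out = lin(torch.rand(4, 1)).mean()
print(lin.weight._version)  # version counter before the update
out.backward()
opt.step()                  # updates lin.weight in place
print(lin.weight._version)  # counter has increased: any retained graph that saved the
                            # old weight for backward is now stale, hence the error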
You can call backward() just once at the end of the for loop, or you can break the autodifferentiation graph, for example by calling Z.detach() after loss.backward() (a sketch of this is shown below). This trick is sometimes used to avoid overly complex and inefficient backpropagation (check, for example, this). However, in both cases this changes the structure of the function being optimized, so be sure of what you are doing.
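For instance, here is a minimal sketch of the detach() approach, reusing the names from your code: the moving average keeps its numerical value, but gradients no longer flow into previous iterations.

Z = torch.tensor([1.0])
X = getTrainingPoints()

for i in range(nEpochs):
    optimizer.zero_grad()
    Z = getUpdatedZ(X, Z)
    loss = lossFunction(X, Z)
    loss.backward(retain_graph=True)
    optimizer.step()
    # Keep Z's value for the next iteration, but cut it out of the autograd graph
    # so the next backward() does not reach the weights updated by optimizer.step().
    Z = Z.detach()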