I'm trying to add more than one generator training step per cycle to a GAN, i.e. I want my Generator to update its parameters n times every m updates of the Discriminator, where n > m.
I wrote this piece of code:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]
        # Training Generator
        for i in range(gen_advantage):
            noise = torch.randn(batch_size, z_dim).to(device)
            fake = gen(noise)
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            lossG.backward()
            opt_gen.step()
            gen.zero_grad()
        # Training Discriminator
        for i in range(disc_advantage):
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) * 0.5
            lossD.backward()  # Breaks here
            opt_disc.step()
            disc.zero_grad()
For context, criterion is a BCELoss, opt_gen and opt_disc are optim.Adam, disc and gen are my Discriminator and Generator instances, and the images in the loader are 28x28.
So, this code throws an error at the lossD.backward() line, even if disc_advantage == 1:
RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.
I can't get why, since in my understanding I'm neither accessing freed tensors nor backwarding lossD multiple times.
Anyhow, I tried, as suggested, putting retain_graph=True in the lossG.backward() call (in the generator loop), but that throws a different error:
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 784]], which is output 0 of AsStridedBackward0, is at version 15; expected version 14 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
which I truly can't understand either, since the error is thrown at the same line as before, i.e. lossD.backward().
That's it. I tried figuring it out alone by scouring the net for explanations of how PyTorch gradients work, but I only found theoretical articles on how gradients are computed, which, although interesting, were not what I needed.
So please help.
PyTorch frees the computation graph for fake after the backward pass in the generator loop, but you then attempt to compute lossD from it. So this line is actually your main issue:
disc_fake = disc(fake).view(-1)
Instead, you have a couple of options.
1. Just detach fake from the graph, i.e. tell PyTorch not to track gradients back through the generator. Your code fails at lossD.backward() because the graph it would need to backward through has already been freed. So change the line to disc_fake = disc(fake.detach()).view(-1).
2. If for some reason you need the generator's graph (though I don't think you do), you can tell PyTorch to preserve it from lossG, i.e. lossG.backward(retain_graph=True). Note that your second error then comes from opt_gen.step() modifying the generator's weights in place, which invalidates the values the retained graph saved for the backward pass.
I'm guessing you want option 1. Option 2 is probably not what you actually want, and it will increase your memory footprint.
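For reference, here is a minimal runnable sketch of the corrected loop with option 1 applied. The tiny stand-in networks, layer sizes, learning rate, and the gen_advantage/disc_advantage values are placeholders; only the detach() change and the zero_grad/backward/step ordering matter.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
z_dim, img_dim, batch_size = 16, 784, 8

# Stand-ins for your Generator and Discriminator
gen = nn.Sequential(nn.Linear(z_dim, img_dim), nn.Tanh())
disc = nn.Sequential(nn.Linear(img_dim, 1), nn.Sigmoid())
criterion = nn.BCELoss()
opt_gen = torch.optim.Adam(gen.parameters(), lr=3e-4)
opt_disc = torch.optim.Adam(disc.parameters(), lr=3e-4)
gen_advantage, disc_advantage = 3, 1

real = torch.rand(batch_size, img_dim)  # stand-in for one batch

# Generator steps: fresh noise each iteration, so each backward()
# consumes its own graph and nothing is reused after being freed.
for _ in range(gen_advantage):
    noise = torch.randn(batch_size, z_dim)
    fake = gen(noise)
    output = disc(fake).view(-1)
    lossG = criterion(output, torch.ones_like(output))
    opt_gen.zero_grad()
    lossG.backward()
    opt_gen.step()

# Discriminator steps: fake.detach() cuts the link back to the
# generator, so backward() never touches the already-freed graph.
for _ in range(disc_advantage):
    disc_real = disc(real).view(-1)
    lossD_real = criterion(disc_real, torch.ones_like(disc_real))
    disc_fake = disc(fake.detach()).view(-1)
    lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
    lossD = (lossD_real + lossD_fake) * 0.5
    opt_disc.zero_grad()
    lossD.backward()  # no RuntimeError now
    opt_disc.step()
```

Note that zeroing gradients before backward() (rather than after step()) also avoids silently accumulating gradients across iterations.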