python, pytorch, generative-adversarial-network

Understanding PyTorch gradients and the backward function when calling backward more than once


I'm trying to add more than one generator training step per cycle to a GAN, i.e. I want my Generator to update its parameters n times every m updates of the Discriminator, where n > m.

I wrote this piece of code:

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, 784).to(device)
        batch_size = real.shape[0]

        # Training Generator
        for i in range(gen_advantage):
            noise = torch.randn(batch_size, z_dim).to(device)
            fake = gen(noise)
            output = disc(fake).view(-1)
            lossG = criterion(output, torch.ones_like(output))
            lossG.backward()
            opt_gen.step()
            gen.zero_grad()

        # Training Discriminator
        for i in range(disc_advantage):
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))
            disc_fake = disc(fake).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
            lossD = (lossD_real + lossD_fake) * 0.5
            lossD.backward() # Breaks here
            opt_disc.step()
            disc.zero_grad()

For context, criterion is an nn.BCELoss, opt_gen and opt_disc are optim.Adam optimizers, disc and gen are my Discriminator and Generator instances, and the images in the loader are 28x28.
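
Roughly, the rest of my setup looks like this (a simplified sketch, not my exact architectures or hyperparameters):

import torch
import torch.nn as nn
import torch.optim as optim

device = "cuda" if torch.cuda.is_available() else "cpu"
z_dim = 64
lr = 3e-4

# simplified generator: noise vector -> flattened 28x28 image in [0, 1]
gen = nn.Sequential(
    nn.Linear(z_dim, 256),
    nn.LeakyReLU(0.1),
    nn.Linear(256, 784),
    nn.Sigmoid(),
).to(device)

# simplified discriminator: flattened 28x28 image -> real/fake probability
disc = nn.Sequential(
    nn.Linear(784, 512),
    nn.LeakyReLU(0.1),
    nn.Linear(512, 1),
    nn.Sigmoid(),
).to(device)

criterion = nn.BCELoss()
opt_gen = optim.Adam(gen.parameters(), lr=lr)
opt_disc = optim.Adam(disc.parameters(), lr=lr)

gen_advantage, disc_advantage = 2, 1  # generator gets more update steps per batch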

So, this code throws an error at the lossD.backward() line, even if disc_advantage == 1:

RuntimeError: Trying to backward through the graph a second time (or directly access saved tensors after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved tensors after calling backward.

I can't see why, since in my understanding I'm neither accessing freed tensors nor calling backward on lossD more than once. Anyhow, I tried, as suggested, putting retain_graph=True in the lossG.backward() call (in the generator loop), but that throws a different error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [512, 784]], which is output 0 of AsStridedBackward0, is at version 15; expected version 14 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

which I can't understand either, since the error is thrown at the same line as before, i.e. lossD.backward().

That's it. I tried to figure it out on my own by scouring the net for explanations of how PyTorch gradients work, but I only found theoretical articles on how the gradients are computed, which, although interesting, were not what I needed.

So please help.


Solution

  • PyTorch frees the computation graph that produced fake as soon as lossG.backward() runs. But in the discriminator loop you then compute lossD from that same fake, which forces autograd to walk back through the already-freed graph. So the line below is actually your main issue:

    disc_fake = disc(fake).view(-1)
    

    Instead, you have a couple of options.

    1. Just detach fake from the graph, i.e. tell PyTorch not to backpropagate through the generator when computing the discriminator loss. Your code fails at lossD.backward() because the intermediate values needed to backpropagate through the generator were already freed by lossG.backward(), yet autograd still tries to use them. So you can change the line to disc_fake = disc(fake.detach()).view(-1) (a full corrected discriminator step is sketched after this list).

    2. If for some reason you need gradients to flow back into the generator here (I don't think you do), you can tell PyTorch to preserve the graph from lossG, i.e. lossG.backward(retain_graph=True).

    I'm guessing you want option 1. Option 2 is probably not what you want: it increases your memory footprint, and on its own it still fails in your loop, because opt_gen.step() modifies the generator's weights in place while the retained graph still references them; that is most likely where your second, in-place-modification error comes from.
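
    Putting option 1 into your training loop, the discriminator step would look roughly like this (same variable names as in your code; I've also moved disc.zero_grad() before the backward call so that gradients the generator step accumulated on the discriminator don't leak into its update):

        # Training Discriminator
        for i in range(disc_advantage):
            disc_real = disc(real).view(-1)
            lossD_real = criterion(disc_real, torch.ones_like(disc_real))

            # detach() cuts the graph at fake, so lossD.backward() stays inside
            # the discriminator and never touches the (already freed) generator graph
            disc_fake = disc(fake.detach()).view(-1)
            lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))

            lossD = (lossD_real + lossD_fake) * 0.5
            disc.zero_grad()  # clear grads the generator step accumulated on disc
            lossD.backward()  # works now: no second backward through the generator
            opt_disc.step()

    With the detach in place you don't need retain_graph=True on lossG.backward() at all.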