deep-learning pytorch generative-adversarial-network dcgan

Two questions on DCGAN: data normalization and fake/real batch


I am analyzing a meta-learning class that uses DCGAN + Reptile for image generation.

I have two questions about this code.

First question: why, during DCGAN training (line 74), is a training batch made up of real examples (real_batch) and fake examples (fake_batch) created?

training_batch = torch.cat([real_batch, fake_batch])

Why is training done on a mix of real and fake images? I have seen many DCGANs, but never one trained this way.

The second question: why are the normalize_data function (line 49) and the unnormalize_data function (line 55) used during training?

def normalize_data(data):
    # map values from [0, 1] to [-1, 1] (in place)
    data *= 2
    data -= 1
    return data


def unnormalize_data(data):
    # map values from [-1, 1] back to [0, 1] (in place)
    data += 1
    data /= 2
    return data

The project uses the MNIST dataset. If I wanted to use a color dataset like CIFAR10, would I have to modify those normalizations?


Solution

  • Training GANs involves giving the discriminator real and fake examples. Usually, you will see them given on two separate occasions (one real batch, then one fake batch). By default, torch.cat concatenates the tensors along the first dimension (dim=0), which is the batch dimension, so it simply doubles the batch size: the first half is the real images and the second half the fake images.

    To calculate the loss, they adapt the targets accordingly, so that the first half (the original batch size) is classified as real and the second half as fake. From initialize_gan:

    self.discriminator_targets = torch.tensor([1] * self.batch_size + [-1] * self.batch_size, dtype=torch.float, device=device).view(-1, 1)
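
    Below is a minimal, self-contained sketch (not the repository's exact training loop) of what this produces: real and fake images stacked along dim=0, with the +1/-1 targets lined up with the two halves. The shapes and the hinge-style loss at the end are illustrative assumptions only.

    import torch

    batch_size = 4
    # stand-ins: a real MNIST batch normalized to [-1, 1] and generator output (tanh, already in [-1, 1])
    real_batch = torch.rand(batch_size, 1, 28, 28) * 2 - 1
    fake_batch = torch.tanh(torch.randn(batch_size, 1, 28, 28))

    training_batch = torch.cat([real_batch, fake_batch])  # shape: (2 * batch_size, 1, 28, 28)
    targets = torch.tensor([1] * batch_size + [-1] * batch_size,
                           dtype=torch.float).view(-1, 1)  # +1 for the real half, -1 for the fake half

    # With +/-1 targets, a hinge-style loss is one natural choice (an assumption here;
    # the original code may define its discriminator loss differently):
    logits = torch.randn(2 * batch_size, 1)  # stand-in for discriminator(training_batch)
    d_loss = torch.relu(1 - targets * logits).mean()
    print(training_batch.shape, targets.shape, d_loss.item())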
    

    Images are represented as float values in [0, 1]. The normalisation changes that to values in [-1, 1]. GANs generally use tanh in the generator, so the fake images have values in [-1, 1]; hence the real images should be in the same range, otherwise it would be trivial for the discriminator to distinguish the fake images from the real ones.

    If you want to display these images, you need to unnormalise them first, i.e. convert them back to values in [0, 1].
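
    As a quick sanity check of that round trip, here is a small sketch using the normalize_data / unnormalize_data functions from the question (note that they modify their argument in place, hence the clone() calls):

    import torch

    def normalize_data(data):  # as in the question: [0, 1] -> [-1, 1], in place
        data *= 2
        data -= 1
        return data

    def unnormalize_data(data):  # as in the question: [-1, 1] -> [0, 1], in place
        data += 1
        data /= 2
        return data

    images = torch.rand(8, 1, 28, 28)                # values in [0, 1], as loaded from MNIST
    normalized = normalize_data(images.clone())      # now in [-1, 1], matching the tanh output range
    restored = unnormalize_data(normalized.clone())  # back to [0, 1], ready for display

    print(normalized.min().item(), normalized.max().item())  # close to -1.0 and 1.0
    print(torch.allclose(restored, images))                  # True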

    The project uses the MNIST dataset. If I wanted to use a color dataset like CIFAR10, would I have to modify those normalizations?

    No, you don't need to change them, because colour images also have their values in [0, 1]; there are simply more values, representing the three channels (RGB).
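
    For example, a sketch assuming torchvision is available: CIFAR10 images come out of ToTensor() with shape (3, 32, 32) and values in [0, 1], so the very same normalisation applies, just over three channels.

    import torchvision
    import torchvision.transforms as transforms

    dataset = torchvision.datasets.CIFAR10(root="./data", download=True,
                                           transform=transforms.ToTensor())
    image, _ = dataset[0]      # tensor of shape (3, 32, 32), values in [0, 1]
    image = image * 2 - 1      # same normalisation as normalize_data -> [-1, 1]
    print(image.shape, image.min().item(), image.max().item())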