deep-learning, one-to-many, mnist, generative-adversarial-network, cgan

Generator is not learning in One to Many CGANs


I am still new to GANs and am trying to implement the One-to-Many CGAN on the MNIST dataset to generate a sequence of digit images (the number of output images equals the number of generators) conditioned on a sum. For example, a 3-generator, 1-discriminator CGAN should generate three digit images such as [3, 2, 6] when the inputs are [noise, noise, noise] and the condition is 11 (their sum).

However, I ran into a problem where the Discriminator's loss kept decreasing towards zero while the Generators' loss kept increasing. As a result, the Generators produce noise instead of meaningful images.

I suspected this was because the Discriminator was too strong, so I added Dropout layers and decreased the number of filters, but nothing changed.

I am trying to get this model to work properly as I described above.

Custom Dataset: The dataset consists of X with shape (data_size, digit_num, channels, height, width), where digit_num is the number of digits in the input sequence, and Y, the sum label, which can take digit_num * 9 + 1 possible values (the sums 0 through 9 * digit_num).

digit_num = 3
label_num = 9 * digit_num + 1
data_size = 120000

dataset = SumMNISTDataset(
    "mnist",
    0,
    digit_num,
    data_size,
    transforms.Compose([transforms.Grayscale(), 
                        transforms.Normalize(127.5, 127.5)]),
)

dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
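
To make the data format above concrete, here is a stripped-down sketch of what such a dataset could look like (my actual SumMNISTDataset differs in its constructor arguments and other details):

import torch
from torch.utils.data import Dataset
from torchvision import datasets, transforms


class SumMNISTDatasetSketch(Dataset):
    """Groups digit_num MNIST images per sample, labelled with the sum of their digits."""

    def __init__(self, root, digit_num, data_size, transform=None):
        self.mnist = datasets.MNIST(root, train=True, download=True)
        self.digit_num = digit_num
        self.data_size = data_size
        self.transform = transform
        # digit_num random MNIST indices for each of the data_size samples
        self.indices = torch.randint(0, len(self.mnist), (data_size, digit_num))

    def __len__(self):
        return self.data_size

    def __getitem__(self, idx):
        images, total = [], 0
        for i in self.indices[idx]:
            img, digit = self.mnist[int(i)]
            img = transforms.functional.pil_to_tensor(img).float()  # uint8 [0, 255] -> float
            if self.transform is not None:
                img = self.transform(img)  # e.g. Grayscale + Normalize(127.5, 127.5)
            images.append(img)
            total += digit
        # X: (digit_num, channels, height, width), Y: the digit sum in [0, 9 * digit_num]
        return torch.stack(images), torch.tensor(total)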

Generator and Discriminator Architecture: The model in the paper uses a deeper network for a higher-resolution dataset; in my case, I scaled it down for MNIST.

class Generator(nn.Module):
    def __init__(self, latent_dim, filter_num, label_num, embed_num=50, bias=False):
        super().__init__()
        self.pre_main = nn.Sequential(
            # 7 x 7 x 128
            nn.ConvTranspose2d(latent_dim, filter_num * 4, 7, 1, 0, bias=bias),
            nn.BatchNorm2d(filter_num * 4),
            nn.LeakyReLU(0.2),
        )
        self.condition = nn.Sequential(
            # 1 x 50
            nn.Embedding(label_num, embed_num),
            nn.Linear(embed_num, 49, bias=bias),
        )
        self.main = nn.Sequential(
            # 14 x 14 x 64
            nn.ConvTranspose2d(filter_num * 4 + 1, filter_num * 2, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num * 2),
            nn.LeakyReLU(0.2),
            # 28 x 28 x 1
            nn.ConvTranspose2d(filter_num * 2, 1, 4, 2, 1, bias=bias),
            nn.Tanh(),
        )

    def forward(self, x, y):
        y = self.condition(y).reshape(-1, 1, 7, 7)
        x = self.pre_main(x)
        x = torch.cat((x, y), dim=1)
        x = self.main(x)
        return x

class Discriminator(nn.Module):
    def __init__(self, filter_num, label_num, embed_num=50, bias=True):
        super().__init__()
        self.condition = nn.Sequential(
            # 28 x 28 x 50
            nn.Embedding(label_num, embed_num),
            nn.Linear(embed_num, 28 * 28, bias=bias),
        )
        self.main = nn.Sequential(
            # 14 x 14 x 64
            nn.Conv2d(2, filter_num, 3, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num),
            nn.LeakyReLU(0.2),
            # 7 x 7 x 128
            nn.Conv2d(filter_num, filter_num * 2, 3, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num * 2),
            nn.LeakyReLU(0.2),
            # Dense
            nn.Flatten(),
            nn.Linear(7 * 7 * filter_num * 2, 1, bias=bias),
        )

    def forward(self, x, y):
        y = self.condition(y).reshape(-1, 1, 28, 28)
        x = torch.cat((x, y), dim=1)
        x = self.main(x)
        return x

Initializing: The number of generators equals the number of digits in the input sequence. In theory, if the number of generators is 1, this model reduces to a standard CGAN; however, the model still failed to converge even with a single generator.

learning_rate = 0.0002
beta_1 = 0.5
latent_dim = 100
filter_num = 32
generator_num = digit_num
omega = 1 / generator_num


def weight_ini_G(model):
    if type(model) == nn.Linear:
        nn.init.constant_(model.weight.data, 1 / generator_num)
    elif type(model) == nn.BatchNorm2d:
        nn.init.constant_(model.weight.data, 1 / generator_num)
        nn.init.constant_(model.bias.data, 0)


def weight_ini_D(model):
    if type(model) == nn.Linear:
        nn.init.normal_(model.weight.data, 0.0, 0.2)
    elif type(model) == nn.BatchNorm2d:
        nn.init.normal_(model.weight.data, 1.0, 0.2)
        nn.init.constant_(model.bias.data, 0)


Gs = [
    Generator(latent_dim, filter_num, label_num).to(device).apply(weight_ini_G)
    for _ in range(generator_num)
]
D = Discriminator(filter_num, label_num).to(device).apply(weight_ini_D)

G_optimizers = [
    optim.Adam(G.parameters(), learning_rate, betas=(beta_1, 0.999)) for G in Gs
]
D_optimizer = optim.Adam(D.parameters(), learning_rate, betas=(beta_1, 0.999))

bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()

Helper Functions: Just to clarify, the generate_hybrid function averages the images along the digit_num axis for training.

def generate_fake():
    rand_labels = torch.randint(0, label_num, (batch_size, 1), device=device)
    images = [
        Gs[g](torch.randn((batch_size, latent_dim, 1, 1), device=device), rand_labels)
        for g in range(generator_num)
    ]
    images = torch.stack(images, axis=1).detach_()

    return images, rand_labels


def generate_real():
    images, labels = next(iter(dataloader))
    return images.to(device), labels.to(device)


def generate_hybrid(images):
    if images.shape[0] == digit_num:
        images = torch.mean(images, dim=0)
    elif images.shape[1] == digit_num:
        images = torch.mean(images, dim=1)

    return images

Update the Generators

def update_generators(real, fake):
    ones = torch.ones((batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, _ = real
    total_loss = 0

    for g in range(generator_num):
        hybrid_fake = generate_hybrid(f_images)
        # r_image = r_images[:, g, :, :]

        preds = D(hybrid_fake, f_labels)

        bce_loss = bce(preds, ones)
        # l1_loss = l1(f_image, r_image)
        loss = bce_loss

        Gs[g].zero_grad()
        loss.backward()
        G_optimizers[g].step()

        total_loss += loss.item()

    return total_loss / generator_num

Update the Discriminator

def update_discriminator(real, fake):
    half_batch_size = batch_size // 2
    zeros = torch.zeros((half_batch_size, 1), device=device)
    ones = torch.ones((half_batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, r_labels = real

    f_images = f_images[:half_batch_size]
    f_labels = f_labels[:half_batch_size]
    r_images = r_images[:half_batch_size]
    r_labels = r_labels[:half_batch_size]

    total_loss = 0

    # Train on Real
    hybrid_real = generate_hybrid(r_images)
    real_preds = D(hybrid_real, r_labels)

    bce_r_loss = bce(real_preds, ones)
    D.zero_grad()
    bce_r_loss.backward()

    # Train on Fake
    hybrid_fake = generate_hybrid(f_images)
    fake_preds = D(hybrid_fake, f_labels)

    bce_f_loss = bce(fake_preds, zeros)
    bce_f_loss.backward()
    D_optimizer.step()

    total_loss = (bce_f_loss.item() + bce_r_loss.item()) / 2

    return total_loss

Training the Model

D_losses = []
G_losses = []
epochs = 5
fixed_noise = torch.randn((4, latent_dim, 1, 1), device=device)
fixed_label = torch.randint(0, label_num, (4,), device=device)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}:")
    for batch in range(data_size // batch_size):
        # Generate Fake Images
        fake = generate_fake()

        # Generate Real Images
        real = generate_real()

        D_loss = update_discriminator(real, fake)

        fake = generate_fake()
        G_loss = update_generators(real, fake)

        if batch % 100 == 0:
            print(
                f"[Batch: {(batch + 1) * batch_size :7d}/{data_size}  D_Loss: {D_loss}  G_Loss: {G_loss}]"
            )
            generate_image(epoch, batch, fixed_noise, fixed_label)

        D_losses.append(D_loss)
        G_losses.append(G_loss)

I stopped training after about one epoch because the results it produced were meaningless.

[Plots: Generators' loss and Discriminator's loss]

Here is my code.


Solution

  • Issues

    One of the pain points in training GANs is that the Discriminator can become too powerful too quickly and stop providing a useful learning signal to the Generators. I would suggest a few things.

    1. Adjust the learning rates: decrease the Discriminator's learning rate relative to the Generators'. This slows the Discriminator's training down and gives the Generators more opportunity to learn. For example:
    # Adjusted learning rates
    D_optimizer = optim.Adam(D.parameters(), lr=learning_rate * 0.1, betas=(beta_1, 0.999))
    G_optimizers = [
        optim.Adam(G.parameters(), lr=learning_rate * 2, betas=(beta_1, 0.999)) for G in Gs
    ]
    
    2. Implement a gradient penalty term in the Discriminator's loss to enforce smoothness in its decision boundary; this can help the model significantly. I will point you to this repo for more information. Ref: github_repo. A sketch is shown below.
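    A minimal WGAN-GP-style sketch, assuming D, device, and the hybrid images/labels as defined in your code (the penalty weight lambda_gp=10 is the usual default, not something from your setup):

    import torch

    def gradient_penalty(D, real_images, fake_images, labels, device, lambda_gp=10.0):
        # Random interpolation between real and fake (hybrid) images
        alpha = torch.rand(real_images.size(0), 1, 1, 1, device=device)
        interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

        preds = D(interpolated, labels)

        # Gradient of D's output with respect to the interpolated images
        gradients = torch.autograd.grad(
            outputs=preds,
            inputs=interpolated,
            grad_outputs=torch.ones_like(preds),
            create_graph=True,
            retain_graph=True,
        )[0]

        gradients = gradients.reshape(gradients.size(0), -1)
        # Penalize deviation of the gradient norm from 1
        return lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

    # e.g. in update_discriminator, before D_optimizer.step():
    # gp = gradient_penalty(D, hybrid_real, hybrid_fake, r_labels, device)
    # gp.backward()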

    3. Use label smoothing: try lower target values such as 0.8 for your real images and slightly higher values like 0.2 for your fake images, as sketched below.
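    With BCEWithLogitsLoss this only means replacing the hard ones/zeros targets in your update functions with smoothed ones (values taken from the suggestion above):

    # Smoothed targets instead of torch.ones / torch.zeros
    real_targets = torch.full((batch_size, 1), 0.8, device=device)
    fake_targets = torch.full((batch_size, 1), 0.2, device=device)

    # e.g. in update_discriminator:
    # bce_r_loss = bce(real_preds, real_targets[:half_batch_size])
    # bce_f_loss = bce(fake_preds, fake_targets[:half_batch_size])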

    4. Look at adding a feature-matching loss for the Generators. Instead of only trying to fool the Discriminator, the Generators are trained to match the statistics of features extracted from an intermediate layer of the Discriminator on real versus generated images. A reference to guide you: github_ref_feature_matching_implementation. A rough sketch follows.
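    A rough sketch, assuming the intermediate features are the output of your Discriminator's convolutional stack just before its final Linear layer (the helper d_features below is hypothetical; it simply mirrors your Discriminator's forward pass without the last layer):

    def d_features(D, images, labels):
        # Same as Discriminator.forward, but stop before the final Linear head
        y = D.condition(labels).reshape(-1, 1, 28, 28)
        x = torch.cat((images, y), dim=1)
        return D.main[:-1](x)

    def feature_matching_loss(D, real_images, fake_images, labels):
        real_feats = d_features(D, real_images, labels).detach()  # no gradient through the real path
        fake_feats = d_features(D, fake_images, labels)           # fakes must not be detached here
        # Match the mean feature statistics over the batch
        return torch.mean((real_feats.mean(dim=0) - fake_feats.mean(dim=0)) ** 2)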

    5. I don't see any batching of any sort (note: I may be wrong), but you should consider it if you haven't.

    6. Reduce the Discriminator's capacity: you said you already reduced the number of filters and added Dropout layers, which are good steps. However, you can reduce the Discriminator's capacity further by lowering the number of layers or the filter counts; this can significantly help the GAN, for example as shown below.
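    For instance, something as simple as giving the Discriminator fewer base filters than the Generators (the halving below is just an illustration, not a tuned value):

    # Discriminator with half the base filter count
    D = Discriminator(filter_num // 2, label_num).to(device).apply(weight_ini_D)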

    7. When training GANs, adding noise to the Discriminator's inputs (e.g., Gaussian noise on the input images) can be helpful, because it makes the Discriminator's job harder. A sketch is shown below.
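    A minimal sketch (the standard deviation 0.1 is an arbitrary starting point; it is common to decay it over training):

    noise_std = 0.1  # arbitrary starting value, consider decaying it over epochs

    def add_instance_noise(images, std=noise_std):
        # Gaussian noise on the Discriminator's image inputs only
        return images + std * torch.randn_like(images)

    # e.g. in update_discriminator:
    # real_preds = D(add_instance_noise(hybrid_real), r_labels)
    # fake_preds = D(add_instance_noise(hybrid_fake), f_labels)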

    8. Finally, latent dimensionality: play around with different sizes for the latent space (e.g., increase or decrease latent_dim).

    Training models like this is largely a matter of experimenting with these hyperparameters.