I am still new to GANs and am trying to implement the One-to-Many CGAN on the MNIST dataset to generate a sequence of images (with the number of output images equal to the number of generators) from a sum condition. For example, a 3-generator, 1-discriminator CGAN model would generate three digit images such as [3, 2, 6] if the inputs are [noise, noise, noise] and the condition (the sum) is 11.
However, I ran into a problem where the Discriminator's loss kept decreasing toward zero while the Generators' loss kept increasing. As a result, the Generators produce noise instead of meaningful images.
I suspected this was because the Discriminator was too strong, so I added Dropout layers and decreased the number of filters, but nothing changed.
I am trying to get this model to work properly as I described above.
Custom Dataset: This dataset consists of X with shape (data_size, digit_num, channels, height, width), where digit_num is the number of digits in the input sequence, and Y, the sum label, which can take digit_num * 9 + 1 possible values.
digit_num = 3
label_num = 9 * digit_num + 1
data_size = 120000
dataset = SumMNISTDataset(
    "mnist",
    0,
    digit_num,
    data_size,
    transforms.Compose([transforms.Grayscale(),
                        transforms.Normalize(127.5, 127.5)]),
)
dataloader = DataLoader(dataset, batch_size, shuffle=True, drop_last=True)
Generator and Discriminator Architecture: The model in the paper uses a deeper network for a high-resolution dataset; in my case, I scaled it down for the MNIST dataset.
class Generator(nn.Module):
    def __init__(self, latent_dim, filter_num, label_num, embed_num=50, bias=False):
        super().__init__()
        self.pre_main = nn.Sequential(
            # 7 x 7 x 128
            nn.ConvTranspose2d(latent_dim, filter_num * 4, 7, 1, 0, bias=bias),
            nn.BatchNorm2d(filter_num * 4),
            nn.LeakyReLU(0.2),
        )
        self.condition = nn.Sequential(
            # 1 x 50
            nn.Embedding(label_num, embed_num),
            nn.Linear(embed_num, 49, bias=bias),
        )
        self.main = nn.Sequential(
            # 14 x 14 x 64
            nn.ConvTranspose2d(filter_num * 4 + 1, filter_num * 2, 4, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num * 2),
            nn.LeakyReLU(0.2),
            # 28 x 28 x 1
            nn.ConvTranspose2d(filter_num * 2, 1, 4, 2, 1, bias=bias),
            nn.Tanh(),
        )

    def forward(self, x, y):
        y = self.condition(y).reshape(-1, 1, 7, 7)
        x = self.pre_main(x)
        x = torch.cat((x, y), dim=1)
        x = self.main(x)
        return x
class Discriminator(nn.Module):
    def __init__(self, filter_num, label_num, embed_num=50, bias=True):
        super().__init__()
        self.condition = nn.Sequential(
            # 28 x 28 x 50
            nn.Embedding(label_num, embed_num),
            nn.Linear(embed_num, 28 * 28, bias=bias),
        )
        self.main = nn.Sequential(
            # 14 x 14 x 64
            nn.Conv2d(2, filter_num, 3, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num),
            nn.LeakyReLU(0.2),
            # 7 x 7 x 128
            nn.Conv2d(filter_num, filter_num * 2, 3, 2, 1, bias=bias),
            nn.BatchNorm2d(filter_num * 2),
            nn.LeakyReLU(0.2),
            # Dense
            nn.Flatten(),
            nn.Linear(7 * 7 * filter_num * 2, 1, bias=bias),
        )

    def forward(self, x, y):
        y = self.condition(y).reshape(-1, 1, 28, 28)
        x = torch.cat((x, y), dim=1)
        x = self.main(x)
        return x
Initializing: The number of generators is equal to the number of digits in the input sequence. In theory, if the number of generators equals 1, this model is equivalent to a standard CGAN; however, the model still failed to converge even with a single generator.
learning_rate = 0.0002
beta_1 = 0.5
latent_dim = 100
filter_num = 32
generator_num = digit_num
omega = 1 / generator_num
def weight_ini_G(model):
    if type(model) == nn.Linear:
        nn.init.constant_(model.weight.data, 1 / generator_num)
    elif type(model) == nn.BatchNorm2d:
        nn.init.constant_(model.weight.data, 1 / generator_num)
        nn.init.constant_(model.bias.data, 0)

def weight_ini_D(model):
    if type(model) == nn.Linear:
        nn.init.normal_(model.weight.data, 0.0, 0.2)
    elif type(model) == nn.BatchNorm2d:
        nn.init.normal_(model.weight.data, 1.0, 0.2)
        nn.init.constant_(model.bias.data, 0)

Gs = [
    Generator(latent_dim, filter_num, label_num).to(device).apply(weight_ini_G)
    for _ in range(generator_num)
]
D = Discriminator(filter_num, label_num).to(device).apply(weight_ini_D)
G_optimizers = [
    optim.Adam(G.parameters(), learning_rate, betas=(beta_1, 0.999)) for G in Gs
]
D_optimizer = optim.Adam(D.parameters(), learning_rate, betas=(beta_1, 0.999))
bce = nn.BCEWithLogitsLoss()
l1 = nn.L1Loss()
Helper Functions: Just to clarify, the generate_hybrid function computes the mean of all images along the digit_num axis for training.
def generate_fake():
    rand_labels = torch.randint(0, label_num, (batch_size, 1), device=device)
    images = [
        Gs[g](torch.randn((batch_size, latent_dim, 1, 1), device=device), rand_labels)
        for g in range(generator_num)
    ]
    images = torch.stack(images, axis=1).detach_()
    return images, rand_labels

def generate_real():
    images, labels = next(iter(dataloader))
    return images.to(device), labels.to(device)

def generate_hybrid(images):
    if images.shape[0] == digit_num:
        images = torch.mean(images, dim=0)
    elif images.shape[1] == digit_num:
        images = torch.mean(images, dim=1)
    return images
Update the Generators
def update_generators(real, fake):
    ones = torch.ones((batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, _ = real
    total_loss = 0
    for g in range(generator_num):
        hybrid_fake = generate_hybrid(f_images)
        # r_image = r_images[:, g, :, :]
        preds = D(hybrid_fake, f_labels)
        bce_loss = bce(preds, ones)
        # l1_loss = l1(f_image, r_image)
        loss = bce_loss
        Gs[g].zero_grad()
        loss.backward()
        G_optimizers[g].step()
        total_loss += loss.item()
    return total_loss / generator_num
Update the Discriminator
def update_discriminator(real, fake):
    half_batch_size = batch_size // 2
    zeros = torch.zeros((half_batch_size, 1), device=device)
    ones = torch.ones((half_batch_size, 1), device=device)
    f_images, f_labels = fake
    r_images, r_labels = real
    f_images = f_images[:half_batch_size]
    f_labels = f_labels[:half_batch_size]
    r_images = r_images[:half_batch_size]
    r_labels = r_labels[:half_batch_size]
    total_loss = 0
    # Train on Real
    hybrid_real = generate_hybrid(r_images)
    real_preds = D(hybrid_real, r_labels)
    bce_r_loss = bce(real_preds, ones)
    D.zero_grad()
    bce_r_loss.backward()
    # Train on Fake
    hybrid_fake = generate_hybrid(f_images)
    fake_preds = D(hybrid_fake, f_labels)
    bce_f_loss = bce(fake_preds, zeros)
    bce_f_loss.backward()
    D_optimizer.step()
    total_loss = (bce_f_loss.item() + bce_r_loss.item()) / 2
    return total_loss
Training the Model
D_losses = []
G_losses = []
epochs = 5
fixed_noise = torch.randn((4, latent_dim, 1, 1), device=device)
fixed_label = torch.randint(0, label_num, (4,), device=device)

for epoch in range(epochs):
    print(f"Epoch {epoch + 1}:")
    for batch in range(data_size // batch_size):
        # Generate Fake Images
        fake = generate_fake()
        # Generate Real Images
        real = generate_real()
        D_loss = update_discriminator(real, fake)
        fake = generate_fake()
        G_loss = update_generators(real, fake)
        if batch % 100 == 0:
            print(
                f"[Batch: {(batch + 1) * batch_size :7d}/{data_size} D_Loss: {D_loss} G_Loss: {G_loss}]"
            )
            generate_image(epoch, batch, fixed_noise, fixed_label)
        D_losses.append(D_loss)
        G_losses.append(G_loss)
I stopped training after about one epoch because the results it displayed were meaningless.
Issues
One of the common pain points in training GANs is that the Discriminator becomes too powerful too quickly, so the Generators never receive a useful learning signal. I would suggest a few things.
Lower the learning rate of the Discriminator relative to the Generators. This will slow down the Discriminator's training, giving the Generators more opportunity to learn. For example:

# Adjusted learning rates
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate * 0.1, betas=(beta_1, 0.999))
G_optimizers = [
    optim.Adam(G.parameters(), lr=learning_rate * 2, betas=(beta_1, 0.999)) for G in Gs
]
Implementing a gradient penalty term in the Discriminator's loss function, to enforce smoothness in its decision boundary, can help significantly; a sketch follows below. For more information, see this repo. Ref: github_repo
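Here is a minimal WGAN-GP-style gradient penalty sketch adapted to the conditional D(x, y) in your code; the lambda_gp = 10 weight is a common default rather than something from your setup, and conditioning the interpolated samples on a single label batch is a simplification:

import torch

def gradient_penalty(D, real_images, fake_images, labels, device, lambda_gp=10.0):
    # Random interpolation between real and fake samples
    batch_size = real_images.size(0)
    alpha = torch.rand(batch_size, 1, 1, 1, device=device)
    interpolated = (alpha * real_images + (1 - alpha) * fake_images).requires_grad_(True)

    # Discriminator score on the interpolated images, conditioned on the labels
    preds = D(interpolated, labels)

    # Gradient of the scores w.r.t. the interpolated images
    gradients = torch.autograd.grad(
        outputs=preds,
        inputs=interpolated,
        grad_outputs=torch.ones_like(preds),
        create_graph=True,
        retain_graph=True,
    )[0]

    # Penalize any deviation of the gradient norm from 1
    gradients = gradients.view(batch_size, -1)
    return lambda_gp * ((gradients.norm(2, dim=1) - 1) ** 2).mean()

You could then add it to the Discriminator loss before the backward pass, e.g. d_loss = bce_r_loss + bce_f_loss + gradient_penalty(D, hybrid_real, hybrid_fake, r_labels, device), and call d_loss.backward() once instead of the two separate backward() calls.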
Use label smoothing for the Discriminator targets: try lower values such as 0.8 (instead of 1.0) for real images and slightly higher values such as 0.2 (instead of 0.0) for fake images.
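For instance, inside update_discriminator you could swap the hard targets for soft ones (these exact values are just a starting point):

# Soft targets instead of the hard ones/zeros
real_targets = torch.full((half_batch_size, 1), 0.8, device=device)
fake_targets = torch.full((half_batch_size, 1), 0.2, device=device)

bce_r_loss = bce(real_preds, real_targets)  # instead of bce(real_preds, ones)
bce_f_loss = bce(fake_preds, fake_targets)  # instead of bce(fake_preds, zeros)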
Look at implementing a feature matching loss for the Generators. This loss compares statistics of the features extracted from an intermediate layer of the Discriminator for real and generated samples; a sketch follows below. A reference to guide you: github_ref_feature_matching_implementation
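A rough sketch of what this could look like with your Discriminator; the discriminator_features helper is hypothetical and assumes the exact nn.Sequential layout shown in the question (everything in D.main except the final Linear):

import torch
import torch.nn.functional as F

def discriminator_features(D, x, y):
    # Run the conditional Discriminator up to (but not including) its final Linear layer
    y = D.condition(y).reshape(-1, 1, 28, 28)
    x = torch.cat((x, y), dim=1)
    return D.main[:-1](x)

def feature_matching_loss(D, real_images, fake_images, labels):
    real_feats = discriminator_features(D, real_images, labels).detach()
    fake_feats = discriminator_features(D, fake_images, labels)
    # Match the batch mean of the intermediate features of real and generated samples
    return F.mse_loss(fake_feats.mean(dim=0), real_feats.mean(dim=0))

You would add this (scaled by a small weight) to bce_loss inside update_generators.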
I don't see any batching of any sort (note: I may be wrong), but you should consider it if you haven't.
Reduce Discriminator Capacity: You said that you already reduced the number of filters and added dropout layers, which are good steps. However, you can further reduce the Discriminator's capacity by lowering the number of layers or the filter sizes, which would significantly help the GAN.
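As a purely illustrative example (SmallDiscriminator is a hypothetical slimmed-down variant with one conv block fewer and half the filters, reusing the torch / nn imports from the question):

class SmallDiscriminator(nn.Module):
    # Deliberately weaker: a single conv block and a smaller dense head
    def __init__(self, filter_num, label_num, embed_num=50, bias=True):
        super().__init__()
        self.condition = nn.Sequential(
            nn.Embedding(label_num, embed_num),
            nn.Linear(embed_num, 28 * 28, bias=bias),
        )
        self.main = nn.Sequential(
            nn.Conv2d(2, filter_num, 3, 2, 1, bias=bias),  # 28 x 28 -> 14 x 14
            nn.LeakyReLU(0.2),
            nn.Flatten(),
            nn.Linear(14 * 14 * filter_num, 1, bias=bias),
        )

    def forward(self, x, y):
        y = self.condition(y).reshape(-1, 1, 28, 28)
        x = torch.cat((x, y), dim=1)
        return self.main(x)

D = SmallDiscriminator(filter_num // 2, label_num).to(device).apply(weight_ini_D)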
When training GANs, adding noise can also help: introduce noise to the inputs of the Discriminator (e.g., add Gaussian noise to the input images) to make the Discriminator's job harder.
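A minimal sketch, assuming the images are in the Tanh range [-1, 1] as in your pipeline; the noise level (and decaying it over training) is something to tune:

def add_instance_noise(images, std=0.1):
    # Gaussian noise on the Discriminator's inputs, clamped back into the Tanh range
    return (images + std * torch.randn_like(images)).clamp(-1.0, 1.0)

# e.g. inside update_discriminator:
# real_preds = D(add_instance_noise(hybrid_real), r_labels)
# fake_preds = D(add_instance_noise(hybrid_fake), f_labels)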
Finally, latent dimensionality: play around with different sizes for the latent space (e.g., increase or decrease latent_dim).
Training GANs is largely about experimenting with these hyperparameters.