I'm having issues while including the discriminator to an implementation of an SRGAN
. While training on the Flickr dataset, I see the discriminator fails to learn anything early on (with the BCELoss
showing a value of 100) and never to recover. I played around it a bit and removed the sigmoid in the hopes of using BCEWithLogits
as the loss. This led to the loss varying wildly in the beginning and getting to zero.
What is a good method to debug the discriminator implementation? I suspect way I'm calling the discriminator in training to have an issue.
class DiscriminatorConvBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super(DiscriminatorConvBlock, self).__init__()
num_groups = 8
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=3, stride=stride, padding=1, bias=False),
nn.GroupNorm(num_groups, out_channels),
nn.LeakyReLU(0.2, False),
)
def forward(self, x):
out = self.conv1(x)
return out
class Discriminator(nn.Module):
def __init__(self, low_res_dim):
super(Discriminator, self).__init__()
img_d = int(low_res_dim / 4)
self.conv1 = nn.Sequential(nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1, bias=False),
nn.LeakyReLU(0.2, False),
)
self.conv2 = DiscriminatorConvBlock(64, 64, 2)
self.conv3 = DiscriminatorConvBlock(64, 128, 1)
self.conv4 = DiscriminatorConvBlock(128, 128, 2)
self.conv5 = DiscriminatorConvBlock(128, 256, 1)
self.conv6 = DiscriminatorConvBlock(256, 256, 2)
self.conv7 = DiscriminatorConvBlock(256, 512, 1)
self.conv8 = DiscriminatorConvBlock(512, 512, 2)
self.dense1 = nn.Linear(512 * img_d * img_d , 1024)
self.leakyRelu = nn.LeakyReLU(0.2, False)
self.dense2 = nn.Linear(1024 , 1)
self.drop = nn.Dropout(0.3)
def forward(self, x):
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
out = self.conv4(out)
out = self.conv5(out)
out = self.conv6(out)
out = self.conv7(out)
out = self.conv8(out)
out = out.view(-1, out.size(1) * out.size(2) * out.size(3))
out = self.leakyRelu(self.dense1(out))
out = self.dense2(out)
out = torch.clamp_(out, 0.0, 1.0)
return out
gen_model = Generator().to(device)
disc_model = Discriminator(low_res).to(device)
# VGG terms
vgg = models.vgg19(pretrained=True).to(device)
feature_nodes = ["features.35"]
feature_extractor = create_feature_extractor(vgg, feature_nodes)
feature_extractor_nodes = feature_nodes
normalizeT = transforms.Normalize([ 0.5, 0.5, 0.5 ], [ 0.5, 0.5, 0.5 ])
for model_parameters in feature_extractor.parameters():
model_parameters.requires_grad = False
feature_extractor.eval()
gen_optimizer = optim.Adam(gen_model.parameters(),lr=1e-4)
disc_optimizer = optim.Adam(disc_model.parameters(),lr=1e-5)
gen_scheduler = CosineAnnealingWarmRestarts(gen_optimizer,
T_0 = 8,# Number of iterations for the first restart
T_mult = 1, # A factor increases TiTi after a restart
eta_min = 1e-5) # Minimum learning rate
disc_scheduler = CosineAnnealingWarmRestarts(disc_optimizer,
T_0 = 8,# Number of iterations for the first restart
T_mult = 1, # A factor increases TiTi after a restart
eta_min = 1e-6) # Minimum learning rate
mse_loss = nn.MSELoss()
vgg_loss = nn.MSELoss()
disc_loss = nn.BCEWithLogitsLoss()
disc_loss_generator = nn.BCEWithLogitsLoss()
gen_optimizer.zero_grad()
for epoch in range(num_epochs):
gen_scheduler.step()
disc_scheduler.step()
for i, data in enumerate(tqdm.tqdm(dataloader)):
input_images, labels = data
# forward pass
input_images = input_images.to(device)
lowres_images = transforms.Resize(low_res)(input_images)
gen_highres_images = gen_model(lowres_images.to(device))
for model_parameters in disc_model.parameters():
model_parameters.requires_grad = True
# Discriminator
disc_model.zero_grad()
actual_label = disc_model(input_images.to(device))
# Adversarial loss
d2_loss = (disc_loss(actual_label, torch.ones_like(actual_label,dtype=torch.float)))
d2_loss.backward()
generated_label = disc_model(gen_highres_images.to(device))
d1_loss = (disc_loss(generated_label, torch.zeros_like(generated_label,dtype=torch.float)))
d1_loss.backward(retain_graph=True)
errD = d2_loss + d1_loss
disc_optimizer.step()
gen_model.zero_grad()
# Perceptual loss
mse = mse_loss(normalizeT(gen_highres_images), normalizeT(input_images))
vgg_losses = []
sr_feature = feature_extractor(normalizeT(input_images))
gt_feature = feature_extractor(normalizeT(gen_highres_images))
for i in range(len(feature_extractor_nodes)):
vgg_losses.append(vgg_loss(sr_feature[feature_extractor_nodes[i]],
gt_feature[feature_extractor_nodes[i]]))
for model_parameters in disc_model.parameters():
model_parameters.requires_grad = False
actual_generated_label = disc_model(gen_highres_images.to(device))
gen_disc_loss = disc_loss_generator(actual_generated_label, torch.ones_like(actual_label,dtype=torch.float))
generator_loss = vgg_losses[0] + mse + gen_disc_loss
generator_loss.backward()
gen_optimizer.step()
gen_optimizer.zero_grad()
torch.cuda.empty_cache()
The problem is out = torch.clamp_(out, 0.0, 1.0)
. This doesn't make sense with what you want the model to do and the loss you are using.
BCEWithLogitsLoss
applies a sigmoid. The sigmoid function returns [0.5, 1]
on the range [0, 1]
you are clamping the output to. You are essentially forcing the model to predict class 1
with >=50% confidence for every example.