I tried to build my very own GAN in PyTorch. I wanted to see how the model learns to generate images over time, so I save the image it creates after each epoch, but it seems to save the same image every time. You can see the images from the first 3 epochs. In addition, as you can see, it combines all the generated images into one grid before saving; can I choose only one?
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            nn.Sigmoid()
        )

    def forward(self, x):
        return self.disc(x)

class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, img_dim),
            nn.Tanh()
        )

    def forward(self, x):
        return self.gen(x)
lr = 3e-4
z_dim = 64
image_dim = 256 * 256 * 3
batch_size = 32
num_epochs = 16
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
fixed_noise = torch.randn((batch_size, z_dim)).to(device)
# dataset = load_dataset(data_path="mountain_dataset", transform=transform)
loader = DataLoader(dataset, batch_size=32, num_workers=0, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
criterion = nn.BCELoss()
Second part:
from torchvision.utils import save_image
step = 0
for epoch in range(num_epochs):
    for batch_idx, real in enumerate(dataset):
        real = real.view(-1, image_dim).to(device)
        batch_size = real.shape[0]

        ### Train Discriminator: max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward(retain_graph=True)  # keep the graph so fake can be reused below
        opt_disc.step()

        ### Train Generator: max log(D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(f"Epoch: [{epoch+1}/{num_epochs}]")
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 3, 256, 256)
                data = real.reshape(-1, 3, 256, 256)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)
                # make_grid already returns a tensor
                img_grid_fake_tensor = img_grid_fake
                # Save the tensor as an image
                save_image(img_grid_fake_tensor, f"generated_images/epoch{epoch}.png", normalize=True)
            step += 1
First, in the line for batch_idx, real in enumerate(dataset): you iterate over the dataset, not over the loader. That is, real represents one image, not one batch. If you add print(real.shape) as the first line inside the loop, it will print torch.Size([3, 256, 256]), i.e. a single image rather than a batch. Accordingly, your batch_size will always be 3, which is really the number of channels. So you have to change this line to for batch_idx, real in enumerate(loader):. Then the print will give you torch.Size([32, 3, 256, 256]), which is what you actually want.
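A quick way to double-check this (a minimal sketch, assuming the dataset yields bare image tensors rather than (image, label) tuples):

# One sample from the dataset vs. one batch from the loader
print(next(iter(dataset)).shape)  # torch.Size([3, 256, 256])    -> a single image
print(next(iter(loader)).shape)   # torch.Size([32, 3, 256, 256]) -> a batch of 32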
After that modification, the saved images are indeed different after each epoch, at least for the first few epochs. I tested this with CIFAR10 upscaled (note that CIFAR10 returns (image, label) tuples, so the loop has to unpack the label as well):

dataset = torchvision.datasets.CIFAR10(
    root="dataset/",
    transform=torchvision.transforms.Compose([
        torchvision.transforms.Resize(image_shape),  # image_shape = (256, 256) here
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,))
    ]),
    download=True,
)
dataset = torch.utils.data.Subset(dataset, range(0, 1000))
loader = DataLoader(dataset, batch_size=32, shuffle=True, pin_memory=True)
However, after 3 epochs, the generated images do indeed stay the same. This is because your model stops learning:
Epoch: [1/16] Loss D: 0.6715, Loss G: 3.8073
Epoch: [2/16] Loss D: 50.3202, Loss G: 0.0000
Epoch: [3/16] Loss D: 50.0000, Loss G: 0.0000
Epoch: [4/16] Loss D: 50.0030, Loss G: 0.0000
Epoch: [5/16] Loss D: 50.0000, Loss G: 0.0000
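(A side note on why Loss D sits at exactly 50.0000: PyTorch's nn.BCELoss clamps its log terms to be greater than or equal to -100 for numerical stability, so once the discriminator's sigmoid fully saturates, one of the two averaged terms is pinned at 100 while the other is 0, and (0 + 100) / 2 = 50.)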
You will have to modify your architecture in order to keep training for more epochs. For instance, your discriminator is much smaller than your generator. As a starting point, you might want both models to have a similar number of parameters. You can see the parameter counts like this:
print(
f"Discriminator Parameters:\t{sum(p.numel() for p in disc.parameters())}\n"
f"Generator Parameters:\t\t{sum(p.numel() for p in gen.parameters())}"
)
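For the architecture above, this prints roughly 25.2 million parameters for the discriminator (dominated by the 196608 × 128 input layer) versus roughly 202 million for the generator (dominated by the 1024 × 196608 output layer), so the discriminator has about an eighth of the generator's capacity.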
For instance, increasing the discriminator like this gives me an additional three epochs of learning:
class Discriminator(nn.Module):
    def __init__(self, img_dim):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(img_dim, 512),
            nn.LeakyReLU(0.1),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.1),
            nn.Linear(1024, 1),
            nn.Sigmoid()
        )
    # forward() stays the same as before
With the line img_grid_fake = torchvision.utils.make_grid(fake, normalize=True), you create a grid of all 32 (batch_size) images, which is then saved to a single file.
If you only want to randomly choose four of them for saving, you could do it like this:
import random

fake = gen(fixed_noise).reshape(-1, 3, 256, 256)
# Select 4 random images for the grid
fake = random.choices(fake, k=4)
img_grid_fake = torchvision.utils.make_grid(fake, normalize=True, nrow=2)
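Note that random.choices samples with replacement, so the same image can appear twice in the grid. A sketch of a without-replacement variant using torch.randperm:

fake = gen(fixed_noise).reshape(-1, 3, 256, 256)
# Pick 4 distinct indices, then index the batch directly
idx = torch.randperm(fake.size(0))[:4]
img_grid_fake = torchvision.utils.make_grid(fake[idx], normalize=True, nrow=2)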
If you only want to save a single image instead of a grid:
# Save only image number 6
index = 5  # For a random one: random.randint(0, batch_size - 1)
fake = gen(fixed_noise[index]).reshape(3, 256, 256)
save_image(fake, f"generated_images/epoch{epoch}.png", normalize=True)
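Equivalently, you can generate the full batch and index afterwards; nn.Linear also accepts the 1D input fixed_noise[index], so both variants work:

# Same result, keeping the generator input batched
fake = gen(fixed_noise)[index].reshape(3, 256, 256)
save_image(fake, f"generated_images/epoch{epoch}.png", normalize=True)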
I deleted the lines

data = real.reshape(-1, 3, 256, 256)
img_grid_real = torchvision.utils.make_grid(data, normalize=True)

because they are unnecessary unless you actually want to save some real samples, too.