I am implementing a conditional GAN for image generation with text embeddings from scratch, and I am getting the above error exactly in the BatchNorm1d layer from the embedding_layers in the generator.
generator class:
import torch
import torch.nn as nn

class Generator(nn.Module):
    def __init__(self, embedding_dim=300, latent_dim=100, image_size=64, num_channels=3):
        super(Generator, self).__init__()
        self.embedding_size = embedding_dim
        self.latent_dim = latent_dim
        self.image_size = image_size
        # Define embedding processing layers
        self.embedding_layers = nn.Sequential(
            nn.Linear(embedding_dim, latent_dim),
            nn.BatchNorm1d(latent_dim),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Define noise processing layers
        self.noise_layers = nn.Sequential(
            nn.Linear(latent_dim, image_size * image_size * 4),
            nn.BatchNorm1d(image_size * image_size * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
        # Define image processing layers
        self.conv_layers = nn.Sequential(
            nn.ConvTranspose2d(latent_dim + 256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            nn.ConvTranspose2d(512, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, num_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh()
        )

    def get_latent_dim(self):
        return self.latent_dim

    def forward(self, embeddings, noise):
        # Process embedding
        embedding_features = self.embedding_layers(embeddings)
        # Process noise
        noise_features = self.noise_layers(noise)
        # Combine features
        features = torch.cat((embedding_features, noise_features), dim=1)
        features = features.view(features.shape[0], -1, self.image_size // 16, self.image_size // 16)
        # Generate image
        image = self.conv_layers(features)
        return image
discriminator class:
import torch
import torch.nn as nn

class Discriminator(nn.Module):
    def __init__(self, embedding_dim=300, image_size=64, num_channels=3):
        super(Discriminator, self).__init__()
        # Define image processing layers
        self.conv_layers = nn.Sequential(
            nn.Conv2d(num_channels, 64, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(256, 512, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(512, 1024, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(1024),
            nn.LeakyReLU(0.2, inplace=True),
            nn.Conv2d(1024, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )
        # Define embedding processing layers
        self.embedding_layers = nn.Sequential(
            nn.Linear(embedding_dim, image_size * image_size),
            nn.BatchNorm1d(image_size * image_size),
            nn.LeakyReLU(0.2, inplace=True)
        )

    def forward(self, images, embeddings):
        # Process image
        image_features = self.conv_layers(images)
        # Process embedding
        embedding_features = self.embedding_layers(embeddings)
        embedding_features = embedding_features.view(embedding_features.shape[0], 1, 64, 64)
        # Combine features
        features = torch.cat((image_features, embedding_features), dim=1)
        # Classify
        classification = self.classification_layers(features).view(features.shape[0], -1)
        validity = self.validity_layers(features).view(features.shape[0], -1)
        return validity, classification
train function:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision.utils import save_image
from tqdm import tqdm
def train_gan(generator, discriminator, dataset, batch_size, num_epochs, device):
    """
    Trains a conditional GAN with a generator and a discriminator using a PyTorch
    dataset containing text embeddings and images.
    """
    # Set up loss functions and optimizers
    adversarial_loss = nn.BCELoss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
    # Set up data loader
    data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    generator.to(device)
    discriminator.to(device)
    # Train the GAN
    for epoch in range(num_epochs):
        for i, data in enumerate(tqdm(data_loader)):
            # Load data and labels onto the device
            text_embeddings = data['text_embedding'].to(device)
            real_images = data['image'].to(device)
            # Generate fake images using the generator and the text embeddings
            noise = torch.randn(batch_size, generator.latent_dim).to(device)
            fake_images = generator(text_embeddings, noise)
            # Train the discriminator
            discriminator_optimizer.zero_grad()
            real_labels = torch.ones(real_images.size(0), 1).to(device)
            fake_labels = torch.zeros(fake_images.size(0), 1).to(device)
            real_predictions = discriminator(real_images, text_embeddings)
            real_loss = adversarial_loss(real_predictions, real_labels)
            fake_predictions = discriminator(fake_images.detach(), text_embeddings)
            fake_loss = adversarial_loss(fake_predictions, fake_labels)
            discriminator_loss = real_loss + fake_loss
            discriminator_loss.backward()
            discriminator_optimizer.step()
            # Train the generator
            generator_optimizer.zero_grad()
            fake_predictions = discriminator(fake_images, text_embeddings)
            generator_loss = adversarial_loss(fake_predictions, real_labels)
            generator_loss.backward()
            generator_optimizer.step()
            # Save generated images and model checkpoints every 500 batches
            if i % 500 == 0:
                with torch.no_grad():
                    fake_images = generator(text_embeddings[:16], noise[:16]).detach().cpu()
                save_image(fake_images, f"images/generated_images_epoch_{epoch}_batch_{i}.png", normalize=True, nrow=4)
                torch.save(generator.state_dict(), f"images/generator_checkpoint_epoch_{epoch}_batch_{i}.pt")
                torch.save(discriminator.state_dict(), f"images/discriminator_checkpoint_epoch_{epoch}_batch_{i}.pt")
        # Print loss at the end of each epoch
        print(f"Epoch [{epoch+1}/{num_epochs}] Discriminator Loss: {discriminator_loss.item()}, Generator Loss: {generator_loss.item()}")
main:
# define hyperparameters
torch.cuda.empty_cache()
embedding_dim = 768
img_size = 512
latent_dim = 200
batch_size = 32
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# define the main components
generator = Generator(embedding_dim=embedding_dim, latent_dim=latent_dim, image_size=img_size)
discriminator = Discriminator(embedding_dim=embedding_dim, image_size=img_size)
train_gan(generator=generator,
          discriminator=discriminator,
          dataset=dataset,
          batch_size=batch_size,
          num_epochs=num_epochs,
          device=device)
As for my dataset, it consists of images and text embeddings with the following shapes (image and text embedding, respectively):
torch.Size([3, 512, 512])
torch.Size([1, 768])
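For reference, a minimal dummy dataset that reproduces exactly these shapes (a sketch; the class name is made up, but the 'image' and 'text_embedding' keys match what train_gan reads) would look like:

import torch
from torch.utils.data import Dataset

class DummyTextImageDataset(Dataset):
    """Hypothetical placeholder dataset mimicking the shapes described above."""
    def __init__(self, length=100, image_size=512, embedding_dim=768):
        self.length = length
        self.image_size = image_size
        self.embedding_dim = embedding_dim

    def __len__(self):
        return self.length

    def __getitem__(self, idx):
        return {
            'image': torch.randn(3, self.image_size, self.image_size),  # torch.Size([3, 512, 512])
            'text_embedding': torch.randn(1, self.embedding_dim),       # torch.Size([1, 768])
        }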
If you refer to https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm1d.html, it says that the input must be either (N, C) or (N, C, L). Your input, on the other hand, is of shape (batch_size, 1, embedding_dim), which means that C is 1, and that is what is giving you the error. What you need to do is remove the extra dimension:
from torch import nn
import torch

embedding_dim = 768
latent_dim = 200
batch_size = 10

# Embeddings come in as (batch_size, 1, embedding_dim); squeeze out the singleton
# dim so BatchNorm1d sees the expected (N, C) shape
inputs = torch.randn(batch_size, 1, embedding_dim).squeeze(1)

model = nn.Sequential(
    nn.Linear(embedding_dim, latent_dim),
    nn.BatchNorm1d(latent_dim),
    nn.LeakyReLU(0.2, inplace=True)
)
print(model(inputs).shape)  # torch.Size([10, 200])
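Applied to your code, that means squeezing the embeddings before they reach the generator and discriminator. A sketch, assuming the [1, 768] embedding shape from your dataset:

# In train_gan, drop the singleton dim so embeddings become (batch_size, 768)
text_embeddings = data['text_embedding'].squeeze(1).to(device)

# Or, equivalently, flatten at the start of Generator.forward / Discriminator.forward:
# embeddings = embeddings.view(embeddings.size(0), -1)

Either way, the Linear/BatchNorm1d stack in embedding_layers then receives a 2D (N, C) tensor and the error goes away.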