I get the following error for a GAN model I am using to perform image colorization. It uses the LAB color space as is common in image colorization. The generator generates the a ad b channels for a given L channel. The discriminator is fed all three channels after concatenation.
RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.cuda.FloatTensor [64, 64, 128, 128]], which is output 0 of ReluBackward0, is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).
I believe the error is due to the skip connections but I cannot quite put my finger on it. Any help would be appreciated!
Here is the model:
class NetGen(nn.Module):
'''Generator'''
def __init__(self):
super(NetGen, self).__init__()
self.conv1 = nn.Conv2d(1, 64, 3, stride=2, padding=1, bias=False)
self.bnorm1 = nn.BatchNorm2d(64)
self.relu1 = nn.LeakyReLU(0.1)
self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False)
self.bnorm2 = nn.BatchNorm2d(128)
self.relu2 = nn.LeakyReLU(0.1)
self.conv3 = nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False)
self.bnorm3 = nn.BatchNorm2d(256)
self.relu3 = nn.LeakyReLU(0.1)
self.conv4 = nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False)
self.bnorm4 = nn.BatchNorm2d(512)
self.relu4 = nn.LeakyReLU(0.1)
self.conv5 = nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False)
self.bnorm5 = nn.BatchNorm2d(512)
self.relu5 = nn.LeakyReLU(0.1)
self.deconv6 = nn.ConvTranspose2d(512, 512, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm6 = nn.BatchNorm2d(512)
self.relu6 = nn.ReLU()
self.deconv7 = nn.ConvTranspose2d(512, 256, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm7 = nn.BatchNorm2d(256)
self.relu7 = nn.ReLU()
self.deconv8 = nn.ConvTranspose2d(256, 128, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm8 = nn.BatchNorm2d(128)
self.relu8 = nn.ReLU()
self.deconv9 = nn.ConvTranspose2d(128, 64, 3, stride=2, padding=1, output_padding=1, bias=False)
self.bnorm9 = nn.BatchNorm2d(64)
self.relu9 = nn.ReLU()
self.deconv10 = nn.ConvTranspose2d(64, 2, 3, stride=2, padding=1, output_padding=1, bias=False)
self.tanh = nn.Tanh()
def forward(self, x):
h = x
h = self.conv1(h)
h = self.bnorm1(h)
h = self.relu1(h)
pool1 = h
h = self.conv2(h)
h = self.bnorm2(h)
h = self.relu2(h)
pool2 = h
h = self.conv3(h)
h = self.bnorm3(h)
h = self.relu3(h)
pool3 = h
h = self.conv4(h)
h = self.bnorm4(h)
h = self.relu4(h)
pool4 = h
h = self.conv5(h)
h = self.bnorm5(h)
h = self.relu5(h)
h = self.deconv6(h)
h = self.bnorm6(h)
h = self.relu6(h)
h += pool4
h = self.deconv7(h)
h = self.bnorm7(h)
h = self.relu7(h)
h += pool3
h = self.deconv8(h)
h = self.bnorm8(h)
h = self.relu8(h)
h += pool2
h = self.deconv9(h)
h = self.bnorm9(h)
h = self.relu9(h)
h += pool1
h = self.deconv10(h)
h = self.tanh(h)
return h
class NetDis(nn.Module):
'''Discriminator'''
def __init__(self):
super(NetDis, self).__init__()
self.main = nn.Sequential(
nn.Conv2d(3, 64, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(64),
nn.LeakyReLU(0.1),
nn.Conv2d(64, 128, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(128),
nn.LeakyReLU(0.1),
nn.Conv2d(128, 256, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(256),
nn.LeakyReLU(0.1),
nn.Conv2d(256, 512, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 512, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 512, 8, stride=1, padding=0, bias=False),
nn.BatchNorm2d(512),
nn.LeakyReLU(0.1),
nn.Conv2d(512, 1, 1, stride=1, padding=0, bias=False),
nn.Sigmoid()
)
def forward(self, x):
return self.main(x)
Here is the weight init function:
def weights_init(m):
classname = m.__class__.__name__
if classname.find('Conv') != -1:
nn.init.normal_(m.weight.data, 0.0, 0.02)
elif classname.find('BatchNorm') != -1:
nn.init.normal_(m.weight.data, 1.0, 0.02)
nn.init.constant_(m.bias.data, 0)
Here is the training and validation code:
class Trainer:
def __init__(self, epochs, batch_size, learning_rate, num_workers):
self.epochs = epochs
self.batch_size = batch_size
self.learning_rate = learning_rate
self.num_workers = num_workers
self.train_paths = train_paths
self.val_paths = val_paths
self.real_label = 1
self.fake_label = 0
def train(self):
train_dataset = ColorizeData(paths=self.train_paths)
train_dataloader = DataLoader(train_dataset, batch_size=self.batch_size, num_workers=self.num_workers,pin_memory=True, drop_last = True)
# Model
model_G = NetGen().to(device)
model_D = NetDis().to(device)
model_G.apply(weights_init)
model_D.apply(weights_init)
optimizer_G = torch.optim.Adam(model_G.parameters(),
lr=self.learning_rate, betas=(0.5, 0.999),
eps=1e-8, weight_decay=0)
optimizer_D = torch.optim.Adam(model_D.parameters(),
lr=self.learning_rate, betas=(0.5, 0.999),
eps=1e-8, weight_decay=0)
criterion = nn.BCELoss()
L1 = nn.L1Loss()
model_G.train()
model_D.train()
# train loop
for epoch in range(self.epochs):
print("Starting Training Epoch " + str(epoch + 1))
for i, data in enumerate(tqdm(train_dataloader)):
inputs, input_ab, input_l = data
inputs = inputs.to(device)
input_ab = input_ab.to(device)
input_l = input_l.to(device)
model_D.zero_grad()
label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
output = model_D(torch.cat([input_l, input_ab], dim=1))
errD_real = criterion(torch.squeeze(output), label)
errD_real.backward()
fake = model_G(input_l)
label.fill_(self.fake_label)
output = model_D(torch.cat([input_l, fake.detach()], dim=1))
errD_fake = criterion(torch.squeeze(output), label)
errD_fake.backward()
errD = errD_real + errD_fake
optimizer_D.step()
model_G.zero_grad()
label.fill_(self.real_label)
output = model_D(torch.cat([input_l, fake], dim=1))
errG = criterion(torch.squeeze(output), label)
errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
errG = errG + 100 * errG_L1
errG.backward()
optimizer_G.step()
print(f'Training: Epoch {epoch + 1} \t\t Discriminator Loss: {\
errD / len(train_dataloader)} \t\t Generator Loss: {\
errG / len(train_dataloader)}')
if (epoch + 1) % 1 == 0:
errD_val, errG_val, val_len = self.validate(model_D, model_G, criterion, L1)
print(f'Validation: Epoch {epoch + 1} \t\t Discriminator Loss: {\
errD_val / val_len} \t\t Generator Loss: {\
errG_val / val_len}')
torch.save(model_G.state_dict(), '../Results/Model_GAN/Generator/saved_model_' + str(epoch + 1) + '.pth')
torch.save(model_D.state_dict(), '../Results/Model_GAN/Discriminator/saved_model_' + str(epoch + 1) + '.pth')
def validate(self, model_D, model_G, criterion, L1):
model_G.eval()
model_D.eval()
with torch.no_grad():
valid_loss = 0.0
val_dataset = ColorizeData(paths=self.val_paths)
val_dataloader = DataLoader(val_dataset, batch_size=self.batch_size, num_workers=self.num_workers, pin_memory=True, drop_last = True)
for i, data in enumerate(val_dataloader):
inputs, input_ab, input_l = data
inputs = inputs.to(device)
input_ab = input_ab.to(device)
input_l = input_l.to(device)
label = torch.full((self.batch_size,), self.real_label, dtype=torch.float, device=device)
output = model_D(torch.cat([input_l, input_ab], dim=1))
errD_real = criterion(torch.squeeze(output), label)
fake = model_G(input_l)
label.fill_(self.fake_label)
output = model_D(torch.cat([input_l, fake.detach()], dim=1))
errD_fake = criterion(torch.squeeze(output), label)
errD = errD_real + errD_fake
label.fill_(self.real_label)
output = model_D(torch.cat([input_l, fake], dim=1))
errG = criterion(torch.squeeze(output), label)
errG_L1 = L1(fake.view(fake.size(0),-1), input_ab.view(input_ab.size(0),-1))
errG = errG + 100 * errG_L1
return errD, errG, len(val_dataloader)
EDIT
As suggested by @manaclan here is the code I use to run the pipeline:
trainer = Trainer(epochs = 100, batch_size = 64, learning_rate = 0.0002, num_workers = 2)
trainer.train()
Here is the data loader:
class ColorizeData(Dataset):
def __init__(self, paths):
self.input_transform = T.Compose([T.ToTensor(),
T.Resize(size=(256,256)),
T.Grayscale(),
T.Normalize((0.5), (0.5))
])
self.lab_transform = T.Compose([T.ToTensor(),
T.Resize(size=(256,256)),
T.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
self.paths = paths
def __len__(self) -> int:
return len(self.paths)
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
image = Image.open(self.paths[index]).convert("RGB")
input_image = self.input_transform(image)
image_lab = rgb2lab(image)
image_lab = self.lab_transform(image_lab)
image_l = image_lab[0, :, :]
image_ab = image_lab[1:3, :, :]
return (input_image.float(), image_ab.float(), image_l.float().reshape(1, 256, 256))
Here are the imports:
from typing import Tuple
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import torch
import numpy as np
import os
import torch.nn as nn
import torchvision.models as models
import torchvision
import torch.nn.functional as functional
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
from PIL import Image
import glob
import matplotlib.pyplot as plt
from tqdm import tqdm
from skimage.color import lab2rgb, rgb2lab, rgb2gray
from skimage import io
from torchvision.transforms.functional import resize
To reproduce the error, just use any dataset of color images. I have the following code to get my train, test, and validation images from the folder "Dataset":
path = "../Dataset/"
paths = np.array(glob.glob(path + "/*.jpg"))
rand_indices = np.random.permutation(len(paths)) # Number of images in dataset
train_indices, val_indices, test_indices = rand_indices[:3600], rand_indices[3600:4000], rand_indices[4000:]
train_paths = paths[train_indices]
val_paths = paths[val_indices]
test_paths = paths[test_indices]
NOTE: I am using Google Colab, maybe this might be a potential problem? Also, I am using torch version 1.10.0+cu111. I did use a sequential model without skip connections for the generator before this, and I did not have this error then.
So apparently, the problem is the inplace skip connection written as h += poolX.
Writing this update out of place as h = h + poolX
fixed it. h is needed for gradient calculation in some layers, so inplace modification will mess it up.