pythonpytorchdcganstate-dict

RuntimeError: Error(s) in loading state_dict for Generator: Missing key(s) in state_dict


I was trying to train a DCGAN model using MNIST datasets, but I can't load the gen.state_dict() after I finished training.

import torch
import torch.nn as nn
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter
import torchvision
import os
from torch.autograd import Variable

workspace_dir = '/content/drive/My Drive/practice'
device=torch.device('cuda' if torch.cuda.is_available else 'cpu')
print(device)

img_size=64
channel_img=1
lr=2e-4
batch_size=128
z_dim=100
epochs=10
features_gen=64
features_disc=64
save_dir = os.path.join(workspace_dir, 'logs')
os.makedirs(save_dir, exist_ok=True)
import matplotlib.pyplot as plt
transforms=transforms.Compose([transforms.Resize(img_size),transforms.ToTensor(),transforms.Normalize(mean=(0.5,),std=(0.5,))])
train_data=datasets.MNIST(root='dataset/',train=True,transform=transforms,download=True)
train_loader=torch.utils.data.DataLoader(train_data,batch_size=batch_size,shuffle=True)
count=0
for x,y in train_loader:
  if count==5:
    break
  print(x.shape,y.shape)
  count+=1

class Discriminator(nn.Module):
  def __init__(self,channels_img,features_d):
    super(Discriminator,self).__init__()
    
    self.disc=nn.Sequential(
        #input:N * channels_img * 64 *64
        nn.Conv2d(channels_img,features_d,4,2,1),#paper didn't use batchnorm in the early layers in the discriminator features_d* 32 *32
        nn.LeakyReLU(0.2),
        self._block(features_d,features_d*2,4,2,1),#features_d*2 *16 *16
        self._block(features_d*2,features_d*4,4,2,1),#features_d*4 *8 *8
        self._block(features_d*4,features_d*8,4,2,1), #features_d*8 *4 *4
        nn.Conv2d(features_d*8,1,4,2,0),#1 * 1 *1
        nn.Sigmoid()

    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
        nn.Conv2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),
        nn.BatchNorm2d(out_channels),
        nn.LeakyReLU(0.2)
    )


  def forward(self,x):
    return self.disc(x)

class Generator(nn.Module):
  def __init__(self,Z_dim,channels_img,features_g):
    super(Generator,self).__init__()
    
    self.gen=nn.Sequential(
        #input :n * z_dim * 1 *1
        self._block(Z_dim,features_g*16,4,1,0),#features_g*16 * 4 * 4
        self._block(features_g*16,features_g*8,4,2,1),#features_g*8 * 8 * 8
        self._block(features_g*8,features_g*4,4,2,1),#features_g*4 * 16 * 16
        self._block(features_g*4,features_g*2,4,2,1),#features_g*2 * 32 * 32
        nn.ConvTranspose2d(features_g*2,channels_img,4,2,1), #
        nn.Tanh()# [-1 to 1] normalize the image
    )
  def _block(self,in_channels,out_channels,kernel_size,stride,padding):
    return nn.Sequential(
          nn.ConvTranspose2d(in_channels,out_channels,kernel_size,stride,padding,bias=False),#w'=(w-1)*s-2p+k
          nn.BatchNorm2d(out_channels),
          nn.ReLU()
      )
    
  def forward(self,x):
      return self.gen(x)


def initialize_weights(model):
  for m in model.modules():
    if isinstance(m,(nn.Conv2d,nn.ConvTranspose2d,nn.BatchNorm2d)):
      nn.init.normal_(m.weight.data,0.0,0.02)

gen=Generator(z_dim,channel_img,features_gen).to(device)
disc=Discriminator(channel_img,features_disc).to(device)
initialize_weights(gen)
initialize_weights(disc)
opt_gen=torch.optim.Adam(gen.parameters(),lr=lr,betas=(0.5,0.999))
opt_disc=torch.optim.Adam(disc.parameters(),lr=lr,betas=(0.5,0.999))
criterion=nn.BCELoss()
#fixed_noise=torch.randn(32,z_dim,1,1).to(device)
#writer_real=SummaryWriter(f"logs/real")
#writer_fake=SummaryWriter(f"logs/fake")
step=0
gen.train()
disc.train()


z_sample = Variable(torch.randn(100, z_dim,1,1)).cuda()
for epoch in range(2):
  for batch_idx,(real,_) in enumerate(train_loader):
    real=real.to(device)
    noise=torch.randn((batch_size,z_dim,1,1)).to(device)
    fake=gen(noise)
    
    #Train Discriminator max log(D(x)) + log(1-D(G(z)))
    disc_real=disc(real).reshape(-1)
    loss_disc_real=criterion(disc_real,torch.ones_like(disc_real))
    disc_fake=disc(fake).reshape(-1)
    loss_disc_fake=criterion(disc_fake,torch.zeros_like(disc_fake))
    loss_disc=(loss_disc_fake+loss_disc_real)/2
    disc.zero_grad()
    loss_disc.backward(retain_graph=True)
    opt_disc.step()

    #Train Generator  min log(1-D(G(z))) <--> max log(D(G(z)))
    output=disc(fake).reshape(-1)
    loss_gen=criterion(output,torch.ones_like(output))
    gen.zero_grad()
    loss_gen.backward()
    opt_gen.step()

    
    print(f'\rEpoch [{epoch+1}/{3}] {batch_idx+1}/{len(train_loader)} Loss_D: {loss_disc.item():.4f} Loss_G: {loss_gen.item():.4f}', end='')
  gen.eval()
  f_imgs_sample = (gen(z_sample).data + 1) / 2.0
  filename = os.path.join(save_dir, f'Epoch_{epoch+1:03d}.jpg')
  torchvision.utils.save_image(f_imgs_sample, filename, nrow=10)
  print(f' | Save some samples to {filename}.')
  # show generated image
  grid_img = torchvision.utils.make_grid(f_imgs_sample.cpu(), nrow=10)
  plt.figure(figsize=(10,10))
  plt.imshow(grid_img.permute(1, 2, 0))
  plt.show()
  gen.train()
  
  torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth'))
  torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth'))
  

I can't load the gen state_dict in this step:

# load pretrained model
#gen = Generator(z_dim,1,64)
gen=Generator(z_dim,channel_img,features_gen).to(device)
gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
gen.eval()
gen.cuda()

Here's the error:

RuntimeError                              Traceback (most recent call last)
<ipython-input-18-4bda27faa444> in <module>()
      5 
      6 #gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
----> 7 gen.load_state_dict(torch.load(os.path.join(workspace_dir, 'dcgan_g.pth')))
      8 #/content/drive/My Drive/practice/dcgan_g.pth
      9 gen.eval()

/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1050         if len(error_msgs) > 0:
   1051             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
-> 1052                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1053         return _IncompatibleKeys(missing_keys, unexpected_keys)
   1054 

***RuntimeError: Error(s) in loading state_dict for Generator:
    Missing key(s) in state_dict***: "gen.0.0.weight", "gen.0.1.weight", "gen.0.1.bias", "gen.0.1.running_mean", "gen.0.1.running_var", "gen.1.0.weight", "gen.1.1.weight", "gen.1.1.bias", "gen.1.1.running_mean", "gen.1.1.running_var", "gen.2.0.weight", "gen.2.1.weight", "gen.2.1.bias", "gen.2.1.running_mean", "gen.2.1.running_var", "gen.3.0.weight", "gen.3.1.weight", "gen.3.1.bias", "gen.3.1.running_mean", "gen.3.1.running_var", "gen.4.weight", "gen.4.bias". 
    Unexpected key(s) in state_dict: "disc.0.weight", "disc.0.bias", "disc.2.0.weight", "disc.2.1.weight", "disc.2.1.bias", "disc.2.1.running_mean", "disc.2.1.running_var", "disc.2.1.num_batches_tracked", "disc.3.0.weight", "disc.3.1.weight", "disc.3.1.bias", "disc.3.1.running_mean", "disc.3.1.running_var", "disc.3.1.num_batches_tracked", "disc.4.0.weight", "disc.4.1.weight", "disc.4.1.bias", "disc.4.1.running_mean", "disc.4.1.running_var", "disc.4.1.num_batches_tracked", "disc.5.weight", "disc.5.bias".

Solution

  • You saved the weights with the wrong names. That is you saved the generator's weights as dcgan_d.pth and likewise, saved the descriminator's weights as dcgan_g.pth :

      torch.save(gen.state_dict(), os.path.join(workspace_dir, f'dcgan_d.pth')) # should have been dcgan_g.pth
      torch.save(disc.state_dict(), os.path.join(workspace_dir, f'dcgan_g.pth')) # should have been dcgan_d.pth
    

    and thus when loading, you try to load the wrong weights :

    gen.load_state_dict(torch.load('/content/drive/My Drive/practice/dcgan_g.pth'))
    

    dcgan_g.pth contains the descriminators weights not your generators. First fix the wrong names when you save them. and second, simply rename them accordingly you should be fine.