I modified the example from the Pytorch VAE example to be a convolutional network. I then wanted to implement this in FastAI.
class convVAE(nn.Module):
def __init__(self, dim_z=20):
super(convVAE, self).__init__()
self.cv1 = nn.Conv2d(1, 32, 3, stride=2)
self.cv2 = nn.Conv2d(32, 64, 3, stride=2)
self.fc31 = nn.Linear(2304, dim_z)
self.fc32 = nn.Linear(2304, dim_z)
self.fc4 = nn.Linear(dim_z, 2304)
self.cv5 = nn.ConvTranspose2d(64, 32, 3, stride=2)
self.cv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, output_padding=1)
def encode(self, x):
h1 = F.leaky_relu(self.cv1(x))
h2 = F.leaky_relu(self.cv2(h1)).view(-1, 2304)
return self.fc31(h2), self.fc32(h2)
def reparameterize(self, mu, logvar):
std = torch.exp(0.5*logvar)
eps = torch.randn_like(std)
return mu + eps*std
def decode(self, z):
h5 = F.leaky_relu(self.fc4(z)).view(-1, 64, 6, 6)
h6 = F.leaky_relu(self.cv5(h5))
return torch.sigmoid(self.cv6(h6))
def forward(self, x):
mu, logvar = self.encode(x)
z = self.reparameterize(mu, logvar)
return self.decode(z).view(-1, 784), mu, logvar
def get_loss(res,y):
y_hat, mu, logvar = res
BCE = F.binary_cross_entropy(
y.view(-1, 784),
KLD = -0.5 * torch.sum(1 + logvar -
mu.pow(2) - logvar.exp())
return BCE + KLD
block = DataBlock(
splitter=RandomSplitter(valid_pct=0.2, seed=42),
get_y=(lambda x: x),
batch_tfms=aug_transforms(mult=2., do_flip=False))
path = untar_data(URLs.MNIST)
loaders = block.dataloaders(path/“training”,num_workers=0,bs=32)
loaders.train.show_batch(max_n=4, nrows=1)
mdl = convVAE(5)
learn = Learner(loaders, mdl, loss_func = convVAE.get_loss)
learn.fit(1, cbs=ShortEpochCallback())
The gradient is not computing from the loss, as the parameters all become NaN after one step. The loss function does compute but was relatively large O(1e6). The model and loss function work in the native Pytorch implementation.
There is a mistake in your BCE calculation:
BCE = F.binary_cross_entropy(
y.view(-1, 784), # this should be your model prediction
y_hat, # this should be the ground truth
A simple fix is to swap the two arguments.