python, deep-learning, pytorch, autograd

Unable to update a latent vector using custom loss function in pytorch


I am trying to implement the function below but have had no luck. I am using a VAE model with an encoder and a decoder. I freeze the weights of the VAE decoder and try to optimize a latent vector, which is updated by optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01). The code raises the following error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn


import torch
import torch.nn as nn

class VAE_GD_Loss(nn.Module):
    def __init__(self):
        super(VAE_GD_Loss, self).__init__()

    def forward(self, bad_seg, recons_mask, vector):
        # Squared L2 norm of the latent vector plus the soft Dice loss
        # (Soft_Dice_Loss is defined elsewhere)
        loss = torch.sum(vector**2) + Soft_Dice_Loss(recons_mask, bad_seg)
        return loss
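
Soft_Dice_Loss is not shown here; for reference, a typical soft Dice implementation (an assumption about its exact form, assuming the reconstruction and target are probability masks of the same shape) would look roughly like this:

def Soft_Dice_Loss(pred, target, eps=1e-6):
    # Flatten everything except the batch dimension
    pred = pred.reshape(pred.size(0), -1)
    target = target.reshape(target.size(0), -1)
    intersection = (pred * target).sum(dim=1)
    dice = (2.0 * intersection + eps) / (pred.sum(dim=1) + target.sum(dim=1) + eps)
    return 1.0 - dice.mean()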

def optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01):
    inp__ = inp__.to(device).requires_grad_(True)
    # Encode and reparameterize to get initial latent vector
    with torch.no_grad():
        mu, log_var = model.encoder(inp__)
        z_latent_vect = model.reparameterize(mu, log_var)
    optimizer_lat = torch.optim.Adam([z_latent_vect], lr=learning_rate)
    dec_only = model.decoder
    
    for epoch in range(num_epochs):
        optimizer_lat.zero_grad()
        dec_only.eval()
        # Decode from latent vector
        recons_mask = dec_only(z_latent_vect)
        # Calculate loss
        VGLoss = VAE_GD_Loss()
        loss = VGLoss(inp__, recons_mask, z_latent_vect)
        # loss = Variable(loss, requires_grad=True)
        # Backpropagation (the RuntimeError above is raised on this call)
        loss.backward()
        optimizer_lat.step()
        print(f"Epoch {epoch}: Loss = {loss.item()}")
    
    return z_latent_vect

If the line loss = Variable(loss, requires_grad=True) is uncommented, the code runs, but the loss does not decrease at all. I want to update the latent vector so that it satisfies the constraint encoded in the loss function. Any leads would help!
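
For reference, the same behaviour can be reproduced with a toy stand-in for the decoder (a single frozen linear layer, purely for illustration, not the actual VAE):

import torch
from torch.autograd import Variable

# Toy stand-in for the frozen decoder
decoder = torch.nn.Linear(4, 4)
for p in decoder.parameters():
    p.requires_grad_(False)

with torch.no_grad():
    z = torch.randn(1, 4)            # created under no_grad -> detached, requires_grad=False

optimizer = torch.optim.Adam([z], lr=0.01)
loss = decoder(z).pow(2).sum()
# loss.backward()                    # RuntimeError: element 0 of tensors does not require grad ...

loss = Variable(loss, requires_grad=True)
loss.backward()                      # runs, but loss is now a fresh leaf with no history
print(z.grad)                        # None -> optimizer.step() has nothing to apply to z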


Solution

  • I think your z_latent_vect never has gradient computation enabled: it is initialised inside a no_grad() block, so it is detached from the computation graph and has requires_grad=False. Re-creating it as a torch.nn.Parameter should do the trick; with that change I can see the loss decrease on a very simple VAE that I defined. A self-contained sanity check follows the code below.

    def optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01):
        inp__ = inp__.to(device)

        # Encode once; no gradients are needed for the encoder pass
        with torch.no_grad():
            mu, log_var = model.encoder(inp__)
            z_latent_vect = model.reparameterize(mu, log_var)

        # Re-create the latent vector as a leaf tensor that requires grad
        z_latent_vect = torch.nn.Parameter(z_latent_vect.clone(), requires_grad=True)
        optimizer_lat = torch.optim.Adam([z_latent_vect], lr=learning_rate)

        dec_only = model.decoder
        dec_only.eval()                    # decoder stays frozen / in eval mode
        VGLoss = VAE_GD_Loss()

        for epoch in range(num_epochs):
            optimizer_lat.zero_grad()

            recons_mask = dec_only(z_latent_vect)
            loss = VGLoss(inp__, recons_mask, z_latent_vect)

            loss.backward()
            # print(loss.item())  # you should see it decrease here
            optimizer_lat.step()

        return z_latent_vect
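
    As a self-contained sanity check of the same idea, here is a sketch with a frozen toy linear decoder standing in for the real one (the toy decoder and shapes are assumptions, not the asker's model):

    import torch

    # Frozen toy decoder standing in for model.decoder
    decoder = torch.nn.Linear(4, 4)
    for p in decoder.parameters():
        p.requires_grad_(False)

    with torch.no_grad():
        z_init = torch.randn(1, 4)

    # Leaf parameter with requires_grad=True, so gradients flow back to it
    z = torch.nn.Parameter(z_init.clone(), requires_grad=True)
    opt = torch.optim.Adam([z], lr=0.01)

    for _ in range(5):
        opt.zero_grad()
        loss = torch.sum(z**2) + decoder(z).pow(2).sum()
        loss.backward()
        opt.step()
        print(loss.item())   # should decrease as z is updated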