I am trying to implement this function but have had no luck. I am using a VAE model that comes with an encoder and a decoder. I'm freezing the weights of the VAE decoder and trying to optimize a latent vector, which is updated by the function optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01). The following code raises this error: RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
class VAE_GD_Loss(nn.Module):
    def __init__(self):
        super(VAE_GD_Loss, self).__init__()

    def forward(self, bad_seg, recons_mask, vector):
        # squared L2 norm of the latent vector plus the soft Dice loss
        loss = torch.sum(vector**2) + Soft_Dice_Loss(recons_mask, bad_seg)
        return loss
def optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01):
    inp__ = inp__.to(device).requires_grad_(True)
    # Encode and reparameterize to get the initial latent vector
    with torch.no_grad():
        mu, log_var = model.encoder(inp__)
        z_latent_vect = model.reparameterize(mu, log_var)
    optimizer_lat = torch.optim.Adam([z_latent_vect], lr=learning_rate)
    dec_only = model.decoder
    for epoch in range(num_epochs):
        optimizer_lat.zero_grad()
        dec_only.eval()
        # Decode from the latent vector
        recons_mask = dec_only(z_latent_vect)
        # Calculate loss
        VGLoss = VAE_GD_Loss()
        loss = VGLoss(inp__, recons_mask, z_latent_vect)
        # loss = Variable(loss, requires_grad=True)
        # Backpropagation
        loss.backward()
        optimizer_lat.step()
        print(f"Epoch {epoch}: Loss = {loss.item()}")
    return z_latent_vect
If we uncomment the line loss = Variable(loss, requires_grad=True), the code runs, but the loss does not decrease at all. I want to update the latent vector so that it satisfies the constraint encoded in the loss function. Any leads would help!
I think your z_latent_vect is not enabled for gradient computation at all: it is initialised inside a no_grad() block and is therefore detached from the rest of the computation graph. Defining it as a torch.nn.Parameter should do the trick. At least I can see the loss decrease on a very simple VAE that I defined.
def optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01):
    inp__ = inp__.to(device)
    # Get an initial latent vector from the encoder (no gradients needed here)
    with torch.no_grad():
        mu, log_var = model.encoder(inp__)
        z_latent_vect = model.reparameterize(mu, log_var)
    # Re-wrap it as a leaf tensor that requires grad, so the optimizer can update it
    z_latent_vect = torch.nn.Parameter(z_latent_vect.clone(), requires_grad=True)
    optimizer_lat = torch.optim.Adam([z_latent_vect], lr=learning_rate)
    dec_only = model.decoder
    VGLoss = VAE_GD_Loss()
    for epoch in range(num_epochs):
        optimizer_lat.zero_grad()
        dec_only.eval()
        recons_mask = dec_only(z_latent_vect)
        loss = VGLoss(inp__, recons_mask, z_latent_vect)
        loss.backward()
        # print(loss.item())  # you should see it decrease here
        optimizer_lat.step()
    return z_latent_vect
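Equivalently, you could use z_latent_vect = z_latent_vect.detach().clone().requires_grad_(True) instead of nn.Parameter; the key point is that the optimizer needs a leaf tensor with requires_grad=True, and the tensor produced inside no_grad() is not one.

For completeness, here is roughly the kind of toy setup I mean (ToyVAE, the layer sizes, the random input, and device are just placeholders for illustration; your real model will differ):

import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"

class ToyVAE(nn.Module):
    def __init__(self, in_dim=64, latent_dim=8):
        super().__init__()
        self.enc = nn.Linear(in_dim, 2 * latent_dim)  # outputs mu and log_var stacked
        self.decoder = nn.Sequential(nn.Linear(latent_dim, in_dim), nn.Sigmoid())
        self.latent_dim = latent_dim

    def encoder(self, x):
        h = self.enc(x)
        return h[:, :self.latent_dim], h[:, self.latent_dim:]

    def reparameterize(self, mu, log_var):
        std = torch.exp(0.5 * log_var)
        return mu + std * torch.randn_like(std)

model = ToyVAE().to(device)
for p in model.decoder.parameters():  # freeze the decoder weights
    p.requires_grad_(False)

inp__ = torch.rand(1, 64)  # stand-in for a "bad" segmentation mask
z = optimize_latent_vector(model, inp__, num_epochs=50, learning_rate=0.01)
print(z.requires_grad, z.grad_fn)  # leaf parameter: requires_grad=True, grad_fn=None

With this, the loss printed inside the loop goes down, and only the latent vector changes since the decoder parameters are frozen and not passed to the optimizer.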