I am trying to learn how to implement contrastive loss. Currently my gradients are exploding to infinity, so I think I must have misimplemented something. Could someone take a look at my loss function and tell me if they see an error?
import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2
        z_i_norm = F.normalize(z_i, dim=1)
        z_j_norm = F.normalize(z_j, dim=1)
        cosine_num = torch.matmul(z_i, z_j.T)
        cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
        cosine_similarity = cosine_num / cosine_denom
        numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)
        denominator = cosine_similarity
        diagonal_indices = torch.arange(denominator.size(0))
        denominator[diagonal_indices, diagonal_indices] = 0
        denominator = torch.exp(torch.sum(cosine_similarity, dim=1))
        loss = -torch.log(numerator / denominator).sum()
        return loss
Your implementation of cosine similarity is wrong. You can see this by inspecting the values of the cosine similarity matrix. Run the following:
import torch
import torch.nn.functional as F
bs = 8
d_proj = 64
z_i = torch.randn(bs, d_proj)
z_j = torch.randn(bs, d_proj)
z_i_norm = F.normalize(z_i, dim=1)
z_j_norm = F.normalize(z_j, dim=1)
cosine_num = torch.matmul(z_i, z_j.T)
cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
cosine_similarity = cosine_num / cosine_denom
print(cosine_similarity)
You will see that the values in cosine_similarity are quite large, when cosine similarity should be bounded between -1 and 1. That's because the ratio cancels algebraically: cosine_num[i, j] is the dot product z_i[i] . z_j[j], and cosine_denom[i, j] is that same dot product divided by the norms, so cosine_num / cosine_denom is just the product of the norms ||z_i[i]|| * ||z_j[j]||, not a cosine at all. Exponentiating those large values divided by the temperature is what sends your loss and gradients to infinity.
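You can confirm this by appending a quick check to the snippet above (the exact numbers depend on the random draw, but the maximum will be far above 1 for 64-dimensional Gaussian inputs):

print(cosine_similarity.abs().max())  # roughly the typical norm product (~64 here), not <= 1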
Below are two correct ways of computing pairwise cosine similarity:
# F.cosine_similarity is the built-in way and handles the normalization for you
cosine_similarity = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)
# manual version, to show how pairwise cosine similarity is computed;
# note the denominator must be the outer product of the norms
cosine_similarity = (z_i[:, None] * z_j[None, :]).sum(-1) / (torch.norm(z_i, dim=-1)[:, None] * torch.norm(z_j, dim=-1)[None, :])
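Since you already compute z_i_norm and z_j_norm with F.normalize, a third option, and a minimal sketch of the matmul approach your code was reaching for, is to take the matrix product of the normalized projections (dot products of unit vectors are exactly the cosine similarities, and this avoids materializing the (bs, bs, d_proj) broadcast intermediate):

# rows of z_i_norm and z_j_norm are unit vectors, so this matmul is the full cosine similarity matrix
cosine_similarity = torch.matmul(z_i_norm, z_j_norm.T)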
You also have errors in your cross-entropy implementation. You shouldn't zero the diagonal of the denominator: here the rows of the similarity matrix compare one view against the other, so the diagonal entries are the positive pairs and belong in the denominator. (Note too that denominator = cosine_similarity doesn't copy the tensor, so the in-place zeroing also mutates cosine_similarity itself.) And

denominator = torch.exp(torch.sum(cosine_similarity, dim=1))

should instead be

denominator = torch.exp(cosine_similarity / self.temperature).sum(dim=1)

i.e. include the temperature scaling, and sum after exponentiating rather than before.
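If you do want to keep a manual implementation, a minimal corrected sketch of the forward body (assuming a correctly computed pairwise cosine_similarity as above) uses torch.logsumexp so the sum over exponentials is taken in log space:

logits = cosine_similarity / self.temperature
# -log( exp(logits[i, i]) / sum_k exp(logits[i, k]) )
#   = logsumexp_k(logits[i, k]) - logits[i, i]
loss = (torch.logsumexp(logits, dim=1) - torch.diag(logits)).sum()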
Overall, though, you should use F.cross_entropy instead of manually computing the log-exp values; it applies the log-sum-exp trick internally and is much more numerically stable:
cosine_similarity = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)
# the positive pair for row i sits on the diagonal, so the target class is i
labels = torch.arange(cosine_similarity.shape[0], device=cosine_similarity.device)
loss = F.cross_entropy(cosine_similarity / temperature, labels)
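Putting it all together, here is a minimal sketch of the corrected module, keeping your class name and constructor (one behavioral difference: F.cross_entropy averages over the batch by default, whereas your version summed):

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super().__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        # pairwise cosine similarity: entry (i, j) compares view-1 sample i with view-2 sample j
        cosine_similarity = F.cosine_similarity(projections_1[:, None], projections_2[None, :], dim=2)
        # the positive pair for row i is column i
        labels = torch.arange(cosine_similarity.shape[0], device=cosine_similarity.device)
        return F.cross_entropy(cosine_similarity / self.temperature, labels)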