deep-learning, pytorch, artificial-intelligence, loss-function, self-supervised-learning

Contrastive Loss from Scratch


I am trying to implement (and learn how to implement) contrastive loss. Currently my gradients are exploding to infinity, and I think I must have misimplemented something. I was wondering if someone could take a look at my loss function and tell me if they see an error.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ContrastiveLoss(nn.Module):
    def __init__(self, temperature=0.5):
        super(ContrastiveLoss, self).__init__()
        self.temperature = temperature

    def forward(self, projections_1, projections_2):
        z_i = projections_1
        z_j = projections_2
        z_i_norm = F.normalize(z_i, dim=1)
        z_j_norm = F.normalize(z_j, dim=1)
        cosine_num = torch.matmul(z_i, z_j.T)
        cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
        cosine_similarity = cosine_num / cosine_denom

        numerator = torch.exp(torch.diag(cosine_similarity) / self.temperature)

        denominator = cosine_similarity
        diagonal_indices = torch.arange(denominator.size(0))
        denominator[diagonal_indices, diagonal_indices] = 0
        denominator = torch.exp(torch.sum(cosine_similarity, dim=1))
        loss = -torch.log(numerator / denominator).sum()
        return loss

Solution

  • Your implementation of cosine similarity is wrong. You can see this by inspecting the values of the cosine similarity matrix. Run the following:

    import torch
    import torch.nn.functional as F
    
    bs = 8
    d_proj = 64
    z_i = torch.randn(bs, d_proj)
    z_j = torch.randn(bs, d_proj)
    
    z_i_norm = F.normalize(z_i, dim=1)
    z_j_norm = F.normalize(z_j, dim=1)
    
    cosine_num = torch.matmul(z_i, z_j.T)
    cosine_denom = torch.matmul(z_i_norm, z_j_norm.T)
    cosine_similarity = cosine_num / cosine_denom
    
    print(cosine_similarity)
    

    You will see that the values in cosine_similarity are often far larger than 1, even though cosine similarity is bounded between -1 and 1. The reason is that dividing the raw dot products by the normalized dot products elementwise cancels the dot products and leaves only the product of the vector norms: cosine_num[i, j] / cosine_denom[i, j] = ||z_i[i]|| * ||z_j[j]||, which is not a cosine similarity at all.
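
    You can confirm this by comparing against the outer product of the norms (continuing from the snippet above; the loose tolerance just allows for floating point error):

    norm_product = torch.norm(z_i, dim=1)[:, None] * torch.norm(z_j, dim=1)[None, :]
    print(torch.allclose(cosine_similarity, norm_product, rtol=1e-2))  # True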

    Below are two correct ways of computing pairwise cosine similarity:

    # F.cosine_similarity is preferred for performance
    cosine_similarity = F.cosine_similarity(z_i[:,None], z_j[None,:], dim=2)
    
    # equivalent manual version, to show how pairwise cosine similarity is computed
    cosine_similarity = (z_i[:, None] * z_j[None, :]).sum(-1) / (torch.norm(z_i, dim=-1)[:, None] * torch.norm(z_j, dim=-1)[None, :])
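
    Both give a (bs, bs) matrix where entry [i, j] is the cosine similarity between z_i[i] and z_j[j]. With the same random tensors as above, the values now stay within [-1, 1]:

    cosine_similarity = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2)
    print(cosine_similarity.min().item(), cosine_similarity.max().item())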
    

    You also have errors in your cross-entropy implementation. For example, you shouldn't zero out the diagonal of the denominator matrix (the positive pair belongs in the denominator sum as well), and denominator = torch.exp(torch.sum(cosine_similarity, dim=1)) should instead be denominator = torch.exp(cosine_similarity / self.temperature).sum(dim=1): apply the temperature scaling, and sum after exponentiating rather than before.
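
    For reference, a minimal sketch of the corrected manual computation (the helper name is just for illustration; the mean over the batch matches F.cross_entropy's default reduction):

    import torch
    import torch.nn.functional as F

    def contrastive_loss_manual(z_i, z_j, temperature=0.5):
        # pairwise cosine similarities, scaled by temperature
        sim = F.cosine_similarity(z_i[:, None], z_j[None, :], dim=2) / temperature
        # positive pairs sit on the diagonal
        numerator = torch.exp(torch.diag(sim))
        # sum over all candidates, including the positive pair (do not zero the diagonal)
        denominator = torch.exp(sim).sum(dim=1)
        return -torch.log(numerator / denominator).mean()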

    Overall, you should use F.cross_entropy instead of manually computing the exp/log values; it applies log-softmax internally, which is much more numerically stable.

    cosine_similarity = F.cosine_similarity(z_i[:,None], z_j[None,:], dim=2)
    labels = torch.arange(cosine_similarity.shape[0], device=cosine_similarity.device)
    loss = F.cross_entropy(cosine_similarity/temperature, labels)
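
    Putting it all together, a corrected version of your module could look something like this (a sketch along the lines above, keeping your class name and constructor, with a quick sanity check at the end):

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ContrastiveLoss(nn.Module):
        def __init__(self, temperature=0.5):
            super().__init__()
            self.temperature = temperature

        def forward(self, projections_1, projections_2):
            # entry [i, j] compares projections_1[i] with projections_2[j]
            logits = F.cosine_similarity(projections_1[:, None], projections_2[None, :], dim=2)
            # the positive pair for sample i sits in column i
            labels = torch.arange(logits.shape[0], device=logits.device)
            return F.cross_entropy(logits / self.temperature, labels)

    # sanity check: the loss is finite and the gradients do not blow up
    z1 = torch.randn(8, 64, requires_grad=True)
    z2 = torch.randn(8, 64, requires_grad=True)
    loss = ContrastiveLoss()(z1, z2)
    loss.backward()
    print(loss.item(), z1.grad.abs().max().item())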