Tags: python, deep-learning, pytorch, gmm

Moving to a numerically stable log-sum-exp leads to extremely large loss values


I am working on a network that uses an LSTM together with MDNs (mixture density networks) to predict some distributions. The loss function for these MDNs measures how well my target data fit the predicted distributions: I compute the log-sum-exp, over the mixture components, of the log-probabilities of the targets. With the standard (naive) log-sum-exp I get reasonable initial loss values (around 50-70), although training later runs into NaNs and breaks. Based on what I have read online, a numerically stable version of log-sum-exp is needed to avoid this. However, as soon as I switch to the stable version, my loss values shoot up to the order of 15-20k. They do come down during training, but eventually they also lead to NaNs.

Note: I did not use PyTorch's built-in torch.logsumexp, since I needed a weighted summation based on my mixture weights.
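For what it's worth, a weighted sum does not rule out torch.logsumexp, because the weights can be moved inside the exponent: log(sum_k pi_k * exp(x_k)) = log(sum_k exp(x_k + log(pi_k))). A minimal sketch, assuming the same shapes as in the question (log_probs is batch_size x num_mixtures x num_out_features, pi is batch_size x num_mixtures with strictly non-negative weights); the function name `weighted_logsumexp` is hypothetical:

```python
import torch


def weighted_logsumexp(log_probs, pi, dim=1):
    # log(sum_k pi_k * exp(log_probs_k)) == logsumexp(log_probs_k + log(pi_k))
    # pi: batch x K  ->  log_pi: batch x K x 1, broadcast over the feature axis
    log_pi = torch.log(pi).unsqueeze(2)
    # torch.logsumexp applies the max-shift trick internally
    return torch.logsumexp(log_probs + log_pi, dim=dim)
```

If the mixture weights come from a softmax over logits, torch.log_softmax(logits, dim=1) yields log(pi) directly, which avoids computing log(0) when a weight underflows.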

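import torch

# Both functions are methods of the model class; the enclosing class is omitted here.

def log_sum_exp(self, value, weights, dim=1):
    # Weighted log-sum-exp with the max-shift trick:
    #   log(sum_k w_k * exp(v_k)) = m + log(sum_k w_k * exp(v_k - m)),  m = max_k v_k
    # value:   batch_size x num_mixtures x num_out_features (log-probabilities)
    # weights: batch_size x num_mixtures (mixture weights pi)
    eps = 1e-20
    m, _ = torch.max(value, dim=dim, keepdim=True)
    weighted = torch.exp(value - m) * weights.unsqueeze(2)  # broadcast pi over features
    return m.squeeze(dim) + torch.log(torch.sum(weighted, dim=dim) + eps)

def mdn_loss(self, pi, sigma, mu, target):
    target = target.unsqueeze(1)  # add a mixture axis so target broadcasts against mu
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    probs = m.log_prob(target)  # note: these are log-probabilities
    # Size of probs is batch_size x num_mixtures x num_out_features
    # Size of pi is batch_size x num_mixtures
    loss = -self.log_sum_exp(probs, pi, dim=1)
    return loss.mean()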
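The jump in the loss is not necessarily a bug in the stable version. A minimal sketch, assuming the earlier "standard" version computed log(sum(pi * exp(log_probs)) + eps) with the same eps = 1e-20 (that code is not shown in the question, so `naive_weighted_lse` below is a hypothetical reconstruction): for a target far from every mixture mean, exp(log_probs) underflows to 0, so the naive loss is capped at -log(1e-20) ≈ 46, whereas the stable version reports the true, very large negative log-likelihood. The toy means, sigmas and target value are made up for illustration.

```python
import torch


def naive_weighted_lse(log_probs, pi, dim=1, eps=1e-20):
    # Hypothetical "standard" version: exponentiate first, then take the log.
    return torch.log(torch.sum(torch.exp(log_probs) * pi.unsqueeze(2), dim=dim) + eps)


def stable_weighted_lse(log_probs, pi, dim=1, eps=1e-20):
    # Same computation as log_sum_exp above, written as a free function.
    m, _ = torch.max(log_probs, dim=dim, keepdim=True)
    weighted = torch.exp(log_probs - m) * pi.unsqueeze(2)
    return m.squeeze(dim) + torch.log(torch.sum(weighted, dim=dim) + eps)


# One sample, two mixture components, one output feature.
mu = torch.tensor([[[0.0], [1.0]]])             # 1 x 2 x 1
sigma = torch.ones(1, 2, 1)                     # 1 x 2 x 1
pi = torch.tensor([[0.5, 0.5]])                 # 1 x 2
target = torch.tensor([[200.0]]).unsqueeze(1)   # outlier target, 1 x 1 x 1

log_probs = torch.distributions.Normal(mu, sigma).log_prob(target)  # about -20000 and -19801
naive_loss = -naive_weighted_lse(log_probs, pi)    # exp underflows -> -log(eps), about 46.05
stable_loss = -stable_weighted_lse(log_probs, pi)  # true NLL, about 19802.1
```

In other words, the eps term silently clamps the naive loss for badly fitted targets, which hides them; the stable version exposes how poorly those targets are explained by the mixture.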

After enabling autograd anomaly detection, the NaNs appear to originate at: probs = m.log_prob(target)
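A debugging sketch for narrowing this down, assuming the same tensor shapes as in mdn_loss: torch.autograd.detect_anomaly() makes backward() report the forward operation that produced a NaN, and explicit checks on the inputs of log_prob catch the usual culprits (non-finite mu, sigma or target, or a non-positive sigma, since a Normal requires scale > 0). The helper name `checked_log_prob` and the random toy inputs are hypothetical.

```python
import torch


def checked_log_prob(mu, sigma, target):
    # Hypothetical helper: fail early with a readable message instead of propagating NaNs.
    for name, t in (("mu", mu), ("sigma", sigma), ("target", target)):
        if not torch.isfinite(t).all():
            raise ValueError(f"{name} contains NaN or inf")
    if (sigma <= 0).any():
        raise ValueError("sigma must be strictly positive")
    dist = torch.distributions.Normal(loc=mu, scale=sigma)
    return dist.log_prob(target.unsqueeze(1))  # batch x K x F


# Anomaly detection slows training down, so enable it only while debugging.
with torch.autograd.detect_anomaly():
    mu = torch.zeros(2, 3, 1, requires_grad=True)  # batch=2, K=3 mixtures, F=1 feature
    sigma = torch.ones(2, 3, 1)
    target = torch.randn(2, 1)
    loss = -checked_log_prob(mu, sigma, target).mean()
    loss.backward()
```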

Seeing such huge initial loss values just from switching to the numerically stable version has led me to believe there is a bug in my current implementation. Any help would be appreciated.


Solution

  • Issue resolved. My targets contained some very large values, which led to overflow when log_probs was computed for them. After removing some outlandish data points and normalising the data, the loss came down immediately.
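A minimal sketch of that fix, under stated assumptions: targets are a 2-D tensor (num_samples x num_out_features), "outlandish" is interpreted as an absolute z-score above a hypothetical threshold z_max = 5 in any feature, and normalising means per-feature standardisation using statistics computed on the training targets. All function names are hypothetical; the original post does not say which filtering rule or normalisation was used.

```python
import torch


def drop_outliers(targets, z_max=5.0):
    # Hypothetical rule: drop rows whose |z-score| exceeds z_max in any feature.
    z = (targets - targets.mean(dim=0)) / targets.std(dim=0).clamp_min(1e-8)
    return targets[(z.abs() <= z_max).all(dim=1)]


def fit_standardizer(train_targets):
    # Per-feature mean and std, computed on the training targets only.
    mean = train_targets.mean(dim=0, keepdim=True)
    std = train_targets.std(dim=0, keepdim=True).clamp_min(1e-8)  # guard constant features
    return mean, std


def standardize(targets, mean, std):
    return (targets - mean) / std


def destandardize(values, mean, std):
    # Map standardised values (e.g. predicted mu) back to the original scale.
    return values * std + mean
```

Because the model is then trained on standardised targets, a predicted mean maps back as mu * std + mean and a predicted sigma as sigma * std. Note that a plain z-score filter uses statistics that the outliers themselves inflate, so very heavy contamination may call for a more robust rule.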