I am working on a network that uses an LSTM along with MDNs (mixture density networks) to predict some distributions. The loss function I use for these MDNs involves fitting my target data to the predicted distributions: I compute the log-sum-exp of the log-probabilities of the targets under the mixture components. When I use the standard (naive) log-sum-exp, I get reasonable initial loss values (around 50-70), though training later encounters NaNs and breaks. Based on what I have read online, a numerically stable version of log-sum-exp is required to avoid this problem. However, as soon as I use the stable version, my loss values shoot up to the order of 15-20k. They do come down during training, but eventually they also lead to NaNs.
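For context, here is a tiny standalone example (made-up log-probabilities, not my actual model) of the underflow the naive formula runs into and how the max-shift trick avoids it:

import torch

logp = torch.tensor([-1000.0, -1001.0, -1002.0])   # very negative log-probs

# Naive log-sum-exp: exp(-1000) underflows to 0, so the log becomes -inf
naive = torch.log(torch.exp(logp).sum())            # tensor(-inf)

# Stable version: subtract the max before exponentiating, add it back afterwards
m = logp.max()
stable = m + torch.log(torch.exp(logp - m).sum())   # roughly tensor(-999.59)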
NOTE: I did not use PyTorch's built-in logsumexp function, since I needed a weighted summation based on my mixture components.
def log_sum_exp(self, value, weights, dim=None):
    # Numerically stable log(sum(weights * exp(value))) along `dim`
    eps = 1e-20
    # Shift by the per-row maximum so exp() cannot overflow
    m, _ = torch.max(value, dim=dim, keepdim=True)
    return m.squeeze(dim) + torch.log(
        torch.sum(torch.exp(value - m) * weights.unsqueeze(2), dim=dim) + eps)
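(For comparison, and just my understanding rather than what I am currently running: I think the same weighted reduction could still use the built-in by folding the mixture weights in as log-weights, since log(sum_k w_k * exp(v_k)) equals logsumexp over k of (v_k + log w_k). A sketch of that alternative:)

def log_sum_exp_builtin(self, value, weights, dim=None):
    # Hypothetical alternative using torch.logsumexp; eps guards against log(0)
    eps = 1e-20
    log_w = torch.log(weights.unsqueeze(2) + eps)  # broadcast weights over the feature dim
    return torch.logsumexp(value + log_w, dim=dim)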
def mdn_loss(self, pi, sigma, mu, target):
    # target: batch_size x num_out_features -> batch_size x 1 x num_out_features
    # so it broadcasts against every mixture component
    target = target.unsqueeze(1)
    m = torch.distributions.Normal(loc=mu, scale=sigma)
    # Log-probabilities of the targets under each component
    probs = m.log_prob(target)
    # Size of probs is batch_size x num_mixtures x num_out_features
    # Size of pi is batch_size x num_mixtures
    loss = -self.log_sum_exp(probs, pi, dim=1)
    return loss.mean()
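For reference, the shapes I am feeding in look like this (made-up sizes and random tensors, only meant to show the convention the comments above describe):

import torch

batch_size, num_mixtures, num_out_features = 32, 5, 3
pi = torch.softmax(torch.randn(batch_size, num_mixtures), dim=1)            # batch x K, rows sum to 1
mu = torch.randn(batch_size, num_mixtures, num_out_features)                # batch x K x F
sigma = torch.exp(torch.randn(batch_size, num_mixtures, num_out_features))  # positive scales
target = torch.randn(batch_size, num_out_features)                          # batch x F
# loss = self.mdn_loss(pi, sigma, mu, target)                               # scalar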
Upon enabling PyTorch's anomaly detection, the NaNs seem to originate at: probs = m.log_prob(target)
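In case it helps reproduce this, the anomaly detection I mean is the standard PyTorch switch:

import torch

# Makes autograd track forward ops so a NaN in backward reports the op that produced it
# (noticeably slower, so I only enable it while debugging)
torch.autograd.set_detect_anomaly(True)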
Seeing the loss values jump this much just from switching to the numerically stable version has led me to believe there is a bug in my implementation. Any help would be appreciated.
Issue resolved. My targets contained some very large values, which caused overflow when log_prob was computed for them. After removing these outlandish data points and normalising the data, the loss immediately came down.
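For anyone hitting the same thing, the clean-up was roughly the following (a sketch with a made-up outlier threshold, not my exact pipeline):

import torch

def clean_targets(targets):
    # targets: num_samples x num_out_features
    mean = targets.mean(dim=0)
    std = targets.std(dim=0) + 1e-8
    z = ((targets - mean) / std).abs()
    keep = (z < 5.0).all(dim=1)              # hypothetical cut-off: 5 standard deviations
    kept = targets[keep]
    # Re-normalise using statistics of the kept rows only
    return (kept - kept.mean(dim=0)) / (kept.std(dim=0) + 1e-8)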