Tags: python, math, numeric, exponent

Python instability due to exponentials


Setup: I have the following function in Python, where x can get very large:

import numpy as np

def function(x, pi):
  d = len(pi)
  output = 0
  for r in range(d):
    output += pi[r] * np.exp(-x)
  return output

Input description: x can be very large, which makes np.exp(-x) underflow to 0.0 (in float64 this happens once x is a little above 745), so the entire function evaluates to zero. pi is a vector of probabilities (e.g., [0.5, 0.5]).
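To make the cutoff concrete: in float64, np.exp(-x) is still representable at x = 700 but is exactly 0.0 at x = 746. Because every term in `function` shares the same factor np.exp(-x), the loop equals sum(pi) * np.exp(-x), so one option is to keep the result in log space. This is a minimal sketch, assuming the caller can work with log(output) downstream; `log_function` is a hypothetical name, not part of the original code.

```python
import numpy as np

def log_function(x, pi):
    # sum_r pi[r] * exp(-x) == exp(-x) * sum(pi), so its log is -x + log(sum(pi)).
    # Nothing is exponentiated here, so the result stays finite for any finite x.
    return -x + np.log(np.sum(pi))

pi = np.array([0.5, 0.5])
print(np.exp(-700.0))            # about 9.9e-305, still representable
print(np.exp(-746.0))            # 0.0: underflow
print(log_function(1000.0, pi))  # -1000.0
```
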

Question: Is there a more numerically stable way to implement this function so that the output doesn't collapse to zero? Thanks.

Edit: Since more details were asked for in the comments, here is the entire function:

def entire_function(x_array, pi, r):
  d = len(pi)
  numerator = np.exp(-x_array[r])
  denominator = 0
  for r_prime in range(d):
    denominator += pi[r_prime] * np.exp(-x_array[r_prime])
  return numerator / denominator 

Even trying to use np.log doesn't really help. For example:

a = np.array([np.exp(-900), np.exp(-800)])
print(np.log(a[0]+a[1]))

This prints -inf, because np.exp(-900) and np.exp(-800) both underflow to 0.0 before the log is taken. The summation in the denominator is the nasty part: it prevents me from working with the exponents directly (which would make the computation more numerically stable). I guess this issue is similar to the logsumexp examples in machine learning, with the extra pi[r] factors in front.
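For the two-term example above, NumPy provides np.logaddexp, which computes log(exp(a) + exp(b)) from the exponents a and b without forming the exponentials, so nothing underflows. A weight such as pi[r] can be folded into its exponent as log(pi[r]). A small sketch; the weights of 0.5 are illustrative and not from the original:

```python
import numpy as np

# Naive: both exponentials underflow to 0.0, so the log is -inf.
naive = np.log(np.exp(-900.0) + np.exp(-800.0))

# Stable: log(exp(-900) + exp(-800)) computed directly from the exponents.
stable = np.logaddexp(-900.0, -800.0)

# Weighted: log(0.5*exp(-900) + 0.5*exp(-800)),
# with each weight moved into its exponent as log(0.5).
weighted = np.logaddexp(np.log(0.5) - 900.0, np.log(0.5) - 800.0)

print(naive, stable, weighted)
```

The stable result is essentially -800, because exp(-900) is negligible next to exp(-800).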


Solution

  • Note that in general we have:

    $$p\,e^{x} = e^{\log p}\,e^{x} = e^{\log p + x}$$

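    Using this, each weighted term ps[r'] * exp(-xs[r']) becomes exp(log(ps[r']) - xs[r']), so the denominator is a plain sum of exponentials and we can apply the log-sum-exp trick you linked: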

    import numpy as np
    
    xs = np.array([700, 900])
    ps = np.array([0.6, 0.4])
    
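    def original(xs, ps, r):
        # naive version: np.exp(-900) underflows to 0.0
        ex = np.exp(-xs)
        return ex[r] / (ps*ex).sum()
    
    def log_sum_exp(x):
        # log(sum(exp(x))), shifted by the max so the largest term is exp(0) = 1
        c = x.max()
        return c + np.log(np.sum(np.exp(x - c)))
    
    def adjusted(xs, ps, r):
        # exp(-xs[r]) / sum(ps * exp(-xs)), with the weights folded into the
        # exponents and the denominator evaluated in log space
        return np.exp(-xs[r] - log_sum_exp(-xs + np.log(ps)))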
    

    We can check the result against a high-precision reference computed with the decimal module:

    def check(xs, ps, r):
        # compute with 1,000 significant digits to check the result against
        from decimal import Decimal, getcontext
        getcontext().prec = 1000
        ex = [Decimal.exp(-Decimal(float(xi))) for xi in xs]
        return ex[r] / sum(Decimal(float(pi))*ei for pi,ei in zip(ps, ex))
    
    print([float(adjusted(xs, ps, i)) for i in range(2)])       # [1.6666666666666516, 2.306494211227875e-87]
    print([float(check(xs, ps, i)) for i in range(2)])   # [1.6666666666666667, 2.306494211227896e-87]
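The same log-sum-exp computation can also be vectorized so that every r is handled in one call, which avoids recomputing the denominator for each index. This is a minimal sketch; `adjusted_all` is a hypothetical name, and xs and ps are the example values from above. If SciPy is available, scipy.special.logsumexp(-xs, b=ps) should give the same log-denominator, since its `b` argument multiplies each exponential by a weight.

```python
import numpy as np

def adjusted_all(xs, ps):
    # Fold each weight into its exponent: ps * exp(-xs) == exp(log(ps) - xs).
    z = np.log(ps) - xs
    c = z.max()
    # Log of the denominator via the log-sum-exp trick (shift by the max).
    log_denom = c + np.log(np.sum(np.exp(z - c)))
    # All numerator/denominator ratios at once; exponentiate only at the end.
    return np.exp(-xs - log_denom)

xs = np.array([700, 900])
ps = np.array([0.6, 0.4])
out = adjusted_all(xs, ps)
print(out)
print(np.sum(ps * out))  # should be 1 up to rounding, since sum_r ps[r] * ratio[r] == 1
```

The final sum is a cheap sanity check: the weighted ratios must add up to 1 by their definition.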