Setup: I have the following function in Python, where x can get very large:
import numpy as np
def function(x, pi):
    d = len(pi)
    output = 0
    for r in range(d):
        output += pi[r] * np.exp(-x)
    return output
Input description: x can be very large, causing np.exp(-x) to underflow to zero, which makes the entire function return zero. pi is a vector of probabilities (e.g., [0.5, 0.5]).
Question: Is there a more stable way to implement this function such that it wouldn't lead to the output being zero? Thanks.
Edit: I have decided to give more details since it was asked in the comments. The entire function is
def entire_function(x_array, pi, r):
    d = len(pi)
    numerator = np.exp(-x_array[r])
    denominator = 0
    for r_prime in range(d):
        denominator += pi[r_prime] * np.exp(-x_array[r_prime])
    return numerator / denominator
Even trying to use np.log doesn't really help. For example:
a = np.array([np.exp(-900), np.exp(-800)])
print(np.log(a[0]+a[1]))
This gives me -inf. The summation in the denominator is the troublesome part: it prevents me from working with the exponents directly (which would make the computation more numerically stable). I guess this issue is similar to the logsumexp examples in machine learning, but with the extra pi[r] factors in front.
Note that in general we have:
$$p e^{x} = e^{\log p} e^{x} = e^{\log(p) + x}$$
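Spelled out for the ratio in the question, the identity gives roughly the following. This is only a sketch: the shorthand $a_k$ and $c$ is introduced here for illustration and does not appear in the question, and it assumes every $p_k > 0$ so that $\log p_k$ is finite.

```latex
\begin{aligned}
a_k &= -x_k + \log p_k, \qquad c = \max_k a_k, \\
\sum_k p_k e^{-x_k} &= \sum_k e^{a_k} = e^{c} \sum_k e^{a_k - c}, \\
\frac{e^{-x_r}}{\sum_k p_k e^{-x_k}} &= \exp\!\Big(-x_r - c - \log \sum_k e^{a_k - c}\Big).
\end{aligned}
```

Every exponent $a_k - c$ is at most zero and one of them is exactly zero, so the remaining sum lies between 1 and the number of terms and its logarithm cannot be $-\infty$.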
Using this, we can apply the log-sum-exp trick you linked:
import numpy as np
xs = np.array([700, 900])
ps = np.array([0.6, 0.4])
def original(xs, ps, r):
    ex = np.exp(-xs)
    return ex[r] / (ps*ex).sum()
def log_sum_exp(x):
    c = x.max()
    return c + np.log(np.sum(np.exp(x - c)))
def adjusted(xs, ps, r):
    return np.exp(-xs[r] - log_sum_exp(-xs + np.log(ps)))
We can check this with:
def check(xs, ps, r):
    # calculate to 1,000 decimal places to check the result against
    from decimal import Decimal, getcontext
    getcontext().prec = 1000
    ex = [Decimal.exp(-Decimal(float(xi))) for xi in xs]
    return ex[r] / sum(Decimal(float(pi))*ei for pi,ei in zip(ps, ex))
print([adjusted(xs, ps, i) for i in range(2)]) # [1.6666666666666516, 2.306494211227875e-87]
print([float(check(xs, ps, i)) for i in range(2)]) # [1.6666666666666667, 2.306494211227896e-87]
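If the ratio itself can become too small for a float, it may help to stay in the log domain and only exponentiate at the very end. The following is a minimal sketch of that idea, not part of the code above: the name log_ratios is hypothetical, it uses NumPy's np.logaddexp.reduce in place of the hand-written log_sum_exp, and it assumes all entries of ps are strictly positive.

```python
import numpy as np

def log_ratios(xs, ps):
    """Return log(exp(-xs[r]) / sum(ps * exp(-xs))) for every r at once.

    Hypothetical helper for illustration; assumes all ps > 0.
    """
    xs = np.asarray(xs, dtype=float)
    ps = np.asarray(ps, dtype=float)
    # log of the denominator: log(sum_k exp(-x_k + log p_k)), accumulated
    # pairwise with the numerically stable logaddexp ufunc
    log_denominator = np.logaddexp.reduce(-xs + np.log(ps))
    return -xs - log_denominator

# Same inputs as above: exponentiating should reproduce adjusted()
print(np.exp(log_ratios([700, 900], [0.6, 0.4])))

# A more extreme case: the second ratio underflows as a float,
# but its logarithm is still finite
print(log_ratios([0.0, 2000.0], [0.5, 0.5]))
```

SciPy's scipy.special.logsumexp also accepts a weights argument b, so logsumexp(-xs, b=ps) should be an alternative way to get the log of the denominator if SciPy is available.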