Various frameworks and libraries (such as PyTorch and SciPy) provide specialized implementations for computing log(softmax()), which are faster and numerically more stable than calling log() on the output of softmax(). However, I cannot find the actual Python implementation of this function, log_softmax(), in either of these packages.
Can anyone explain how this is implemented, or better, point me to the relevant source code?
>>> import numpy as np
>>> x = np.array([1, -10, 1000])
>>> np.exp(x) / np.exp(x).sum()
RuntimeWarning: overflow encountered in exp
RuntimeWarning: invalid value encountered in true_divide
array([ 0.,  0., nan])
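The overflow happens because float64 can only represent values up to about exp(709): exp(1000) becomes inf, and dividing inf by inf yields nan. You can check the threshold directly (the exact digits below are illustrative of what NumPy should print):
>>> np.log(np.finfo(np.float64).max)   # largest argument exp() can take before overflowing
709.782712893384
>>> np.exp(1000.0)                     # overflows, with a RuntimeWarning
inf
>>> np.inf / np.inf
nan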
There are two methods to avoid the numerical errors while computing the softmax:
def exp_normalize(x):
    # Shift every entry by the max so the largest value passed to exp() is 0;
    # the common factor exp(-b) cancels between numerator and denominator.
    b = x.max()
    y = np.exp(x - b)
    return y / y.sum()
>>> exp_normalize(x)
array([0., 0., 1.])
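The shift does not change the result, because exp(x - b) = exp(x) * exp(-b) and the exp(-b) factor cancels in the division. On a hypothetical well-behaved input (where the naive formula also works), the two should agree, approximately:
>>> z = np.array([1.0, 2.0, 3.0])
>>> np.exp(z) / np.exp(z).sum()
array([0.09003057, 0.24472847, 0.66524096])
>>> exp_normalize(z)
array([0.09003057, 0.24472847, 0.66524096])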
def log_softmax(x):
    # log(softmax(x)) = x - c - log(sum(exp(x - c)));
    # everything stays in log space, so nothing large is ever exponentiated.
    c = x.max()
    logsumexp = np.log(np.exp(x - c).sum())
    return x - c - logsumexp
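Applied to the same x as above, this returns finite log-probabilities, and exponentiating them recovers the softmax (the outputs shown are approximate):
>>> log_softmax(x)
array([ -999., -1010.,     0.])
>>> np.exp(log_softmax(x))
array([0., 0., 1.])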
Note that a reasonable choice for both b and c in the formulas above is max(x). With this choice, overflow due to exp is impossible: after shifting, the largest argument passed to exp is 0, and exp(0) = 1.
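For reference, recent SciPy versions (1.5+, if I remember correctly) expose this same trick as scipy.special.log_softmax (the implementation lives in scipy/special/_logsumexp.py in the versions I've checked), and it should agree with the hand-rolled version above:
>>> from scipy.special import log_softmax as sp_log_softmax
>>> sp_log_softmax(x)
array([ -999., -1010.,     0.])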