python, machine-learning, pytorch, numerical-methods, mxnet

How is log_softmax() implemented to compute its value (and gradient) with better speed and numerical stability?


Both MXNet and PyTorch provide a special implementation for computing log(softmax()), which is faster and numerically more stable. However, I cannot find the actual Python implementation of this function, log_softmax(), in either package.

Can anyone explain how this is implemented, or better, point me to the relevant source code?


Solution

  • >>> import numpy as np
    >>> x = np.array([1, -10, 1000])
    >>> np.exp(x) / np.exp(x).sum()
    RuntimeWarning: overflow encountered in exp
    RuntimeWarning: invalid value encountered in true_divide
    array([ 0.,  0., nan])
    

    There are two methods to avoid numerical error when computing the softmax:

    The first is the exp-normalize trick: shift the inputs by a constant b before exponentiating, softmax(x)_i = exp(x_i - b) / sum_j exp(x_j - b).

    def exp_normalize(x):
        b = x.max()
        y = np.exp(x - b)
        return y / y.sum()
    
    >>> exp_normalize(x)
    array([0., 0., 1.])
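
    This fixes the overflow, but it is not enough if what we actually want is log(softmax(x)): the small entries have underflowed to 0, so taking their log gives -inf:

    >>> np.log(exp_normalize(x))
    RuntimeWarning: divide by zero encountered in log
    array([-inf, -inf,   0.])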
    

    The second computes the log-softmax directly with the log-sum-exp trick: log_softmax(x)_i = (x_i - c) - log(sum_j exp(x_j - c)).

    def log_softmax(x):
        c = x.max()                              # shift so the largest exponentiated value is exp(0) = 1
        logsumexp = np.log(np.exp(x - c).sum())  # log-sum-exp of the shifted values, no overflow
        return x - c - logsumexp
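
    Applied to the same input, every value stays finite; there is no overflow and no -inf:

    >>> log_softmax(x)
    array([ -999., -1010.,     0.])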
    
    

    Please note that a reasonable choice for both b and c in the formulas above is max(x). With this choice, overflow due to exp is impossible, since the largest number being exponentiated after the shift is 0.
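
    As a sanity check (not the library's actual implementation, which lives in compiled C++/CUDA code rather than Python, which is why you cannot find it in the Python sources), the small NumPy function above should agree with PyTorch's built-in, assuming PyTorch is installed:

    >>> import torch
    >>> import torch.nn.functional as F
    >>> F.log_softmax(torch.tensor([1., -10., 1000.]), dim=0)
    tensor([ -999., -1010.,     0.])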