
Fastest Implementation of the Natural Exponential Function Using SSE


I'm looking for an approximation of the natural exponential function operating on SSE elements, namely __m128 exp( __m128 x ).

I have an implementation that is quick but seems to have very low accuracy:

static inline __m128 FastExpSse(__m128 x)
{
    __m128 a = _mm_set1_ps(12102203.2f); // (1 << 23) / ln(2)
    __m128i b = _mm_set1_epi32(127 * (1 << 23) - 486411);
    __m128  m87 = _mm_set1_ps(-87);
    // fast exponential function, x should be in [-87, 87]
    __m128 mask = _mm_cmpge_ps(x, m87);

    __m128i tmp = _mm_add_epi32(_mm_cvtps_epi32(_mm_mul_ps(a, x)), b);
    return _mm_and_ps(_mm_castsi128_ps(tmp), mask);
}

Does anybody have an implementation with better accuracy that is as fast (or faster)?

I'd be happy if it were written in C style.

Thank you.


Solution

  • The C code below is a translation into SSE intrinsics of an algorithm I used in a previous answer to a similar question.

    The basic idea is to transform the computation of the standard exponential function into computation of a power of 2: expf (x) = exp2f (x / logf (2.0f)) = exp2f (x * 1.44269504). We split t = x * 1.44269504 into an integer i and a fraction f, such that t = i + f and 0 <= f <= 1. We can now compute 2^f with a polynomial approximation, then scale the result by 2^i by adding i to the exponent field of the single-precision floating-point result.
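    For illustration, here is a scalar sketch of the same range reduction in plain C. It is not part of the SSE code: the function name fast_exp_scalar is hypothetical, floor() is emulated with a truncating cast plus a correction so that no math library is needed, and the argument is assumed to lie roughly in [-87, 88] so that the result stays a normal float. The coefficients are those of the SSE version below.

```c
#include <stdint.h>
#include <string.h>

/* Scalar sketch (hypothetical helper): exp(x) = 2^i * 2^f, i = floor(x * log2(e)).
   Assumes roughly -87 <= x <= 88 so the result is a normal float. */
static float fast_exp_scalar (float x)
{
    const float l2e = 1.442695041f;   /* log2(e) */
    const float c0 = 0.3371894346f;   /* quadratic approximation of 2^f on [0,1] */
    const float c1 = 0.657636276f;
    const float c2 = 1.00172476f;

    float t = x * l2e;                /* t = log2(e) * x */
    int32_t i = (int32_t)t;           /* truncate toward zero */
    if ((float)i > t) i--;            /* negative non-integer t: turn trunc into floor */
    float f = t - (float)i;           /* fraction, 0 <= f <= 1 */
    float p = (c0 * f + c1) * f + c2; /* p ~= 2^f, lies in [1, 2) */

    /* Multiply p by 2^i by adding i to the biased exponent field (bits 23..30). */
    uint32_t bits;
    memcpy (&bits, &p, sizeof bits);
    bits += (uint32_t)i << 23;        /* unsigned wrap-around handles negative i */
    memcpy (&p, &bits, sizeof p);
    return p;
}
```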

    One problem with an SSE implementation is that we want to compute i = floorf (t), but plain SSE (before SSE4.1) offers no fast way to compute the floor() function. However, we observe that for positive numbers, floor(x) == trunc(x), and that for negative numbers, floor(x) == trunc(x) - 1, except when x is a negative integer. Applying trunc(x) - 1 to every negative argument is nevertheless harmless: for a negative integer t it produces i = t - 1 and f = 1.0f, and the core approximation can handle an f value of 1.0f. SSE provides an instruction to convert single-precision floating-point operands to integers with truncation (_mm_cvttps_epi32()), so this solution is efficient.
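    A minimal sketch of this truncation-based floor on its own (the helper name approx_floor_sse2 is hypothetical) shows the three cases: exact for t >= 0, exact for negative non-integers, and one too small for negative integers, which turns f = t - floor(t) into 1.0f instead of 0.0f.

```c
#include <emmintrin.h>

/* floor(t) ~= (int)t - signbit(t), using only SSE2.
   Exact for t >= 0 and for negative non-integers;
   returns t - 1 for negative integers t. */
static __m128 approx_floor_sse2 (__m128 t)
{
    __m128i i = _mm_cvttps_epi32 (t);                       /* truncate toward zero */
    __m128i s = _mm_srli_epi32 (_mm_castps_si128 (t), 31);  /* 1 if sign bit set, else 0 */
    return _mm_cvtepi32_ps (_mm_sub_epi32 (i, s));          /* (int)t - signbit(t) */
}
```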

    Peter Cordes points out that SSE4.1 supports a fast floor function, _mm_floor_ps(), so the code below also includes an SSE4.1 variant, selected with #ifdef __SSE4_1__. Not all toolchains automatically predefine the macro __SSE4_1__ when SSE4.1 code generation is enabled, but gcc does.
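    For comparison, an exact floor is also possible with plain SSE2, at the cost of a compare, an AND, and a subtraction. The sketch below (hypothetical helper floor_ps, assumed valid for |t| < 2^31) uses _mm_floor_ps() when __SSE4_1__ is defined and otherwise corrects trunc(t) downward wherever it exceeds t. The answer's code skips this correction because the polynomial tolerates f = 1.0f.

```c
#include <emmintrin.h>
#ifdef __SSE4_1__
#include <smmintrin.h>
#endif

/* Exact floor for |t| < 2^31: SSE4.1 roundps if available, else SSE2 fix-up. */
static __m128 floor_ps (__m128 t)
{
#ifdef __SSE4_1__
    return _mm_floor_ps (t);
#else
    __m128 e  = _mm_cvtepi32_ps (_mm_cvttps_epi32 (t));  /* trunc(t) */
    __m128 gt = _mm_cmpgt_ps (e, t);                     /* true only for negative non-integers */
    return _mm_sub_ps (e, _mm_and_ps (gt, _mm_set1_ps (1.0f)));  /* subtract 1 in those lanes */
#endif
}
```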

    Compiler Explorer (Godbolt) shows that gcc 7.2 compiles the code below into sixteen instructions for plain SSE and twelve instructions for SSE 4.1.

    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <math.h>
    #include <emmintrin.h>
    #ifdef __SSE4_1__
    #include <smmintrin.h>
    #endif
    
    /* max. rel. error = 1.72863156e-3 on [-87.33654, 88.72283] */
    __m128 fast_exp_sse (__m128 x)
    {
        __m128 t, f, e, p, r;
        __m128i i, j;
        __m128 l2e = _mm_set1_ps (1.442695041f);  /* log2(e) */
        __m128 c0  = _mm_set1_ps (0.3371894346f);
        __m128 c1  = _mm_set1_ps (0.657636276f);
        __m128 c2  = _mm_set1_ps (1.00172476f);
    
        /* exp(x) = 2^i * 2^f; i = floor (log2(e) * x), 0 <= f <= 1 */   
        t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
    #ifdef __SSE4_1__
        e = _mm_floor_ps (t);                /* floor(t) */
        i = _mm_cvtps_epi32 (e);             /* (int)floor(t) */
    #else /* __SSE4_1__*/
        i = _mm_cvttps_epi32 (t);            /* i = (int)t */
        j = _mm_srli_epi32 (_mm_castps_si128 (x), 31); /* signbit(t) */
        i = _mm_sub_epi32 (i, j);            /* (int)t - signbit(t) */
        e = _mm_cvtepi32_ps (i);             /* floor(t) ~= (int)t - signbit(t) */
    #endif /* __SSE4_1__*/
        f = _mm_sub_ps (t, e);               /* f = t - floor(t) */
        p = c0;                              /* c0 */
        p = _mm_mul_ps (p, f);               /* c0 * f */
        p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
        p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
        p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
        j = _mm_slli_epi32 (i, 23);          /* i << 23 */
        r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
        return r;
    }
    
    int main (void)
    {
        union {
            float f[4];
            unsigned int i[4];
        } arg, res;
        double relerr, maxrelerr = 0.0;
        int i, j;
        __m128 x, y;
    
        float start[2] = {-0.0f, 0.0f};
        float finish[2] = {-87.33654f, 88.72283f};
    
        for (i = 0; i < 2; i++) {
    
            arg.f[0] = start[i];
            arg.i[1] = arg.i[0] + 1;
            arg.i[2] = arg.i[0] + 2;
            arg.i[3] = arg.i[0] + 3;
            do {
                memcpy (&x, &arg, sizeof(x));
                y = fast_exp_sse (x);
                memcpy (&res, &y, sizeof(y));
                for (j = 0; j < 4; j++) {
                    double ref = exp ((double)arg.f[j]);
                    relerr = fabs ((res.f[j] - ref) / ref);
                    if (relerr > maxrelerr) {
                        printf ("arg=% 15.8e  res=%15.8e  ref=%15.8e  err=%15.8e\n", 
                                arg.f[j], res.f[j], ref, relerr);
                        maxrelerr = relerr;
                    }
                }   
                arg.i[0] += 4;
                arg.i[1] += 4;
                arg.i[2] += 4;
                arg.i[3] += 4;
            } while (fabsf (arg.f[3]) < fabsf (finish[i]));
        }
        printf ("maximum relative error = %15.8e\n", maxrelerr);
        return EXIT_SUCCESS;
    }
    

    An alternative design for fast_exp_sse() extracts the integer portion of the adjusted argument x / log(2) in round-to-nearest mode, using the well-known technique of adding the "magic" conversion constant 1.5 * 2^23 to force rounding in the correct bit position, then subtracting out the same number again. This requires that the SSE rounding mode in effect during the addition is "round to nearest, ties to even", which is the default. wim pointed out in comments that some compilers may optimize out the addition and subtraction of the conversion constant cvt as redundant when aggressive optimization is used, which breaks this code sequence, so it is recommended to inspect the generated machine code. The approximation interval for computation of 2^f is now centered around zero, since -0.5 <= f <= 0.5, which requires a different core approximation.

    /* max. rel. error <= 1.72860465e-3 on [-87.33654, 88.72283] */
    __m128 fast_exp_sse (__m128 x)
    {
        __m128 t, f, p, r;
        __m128i i, j;
    
        const __m128 l2e = _mm_set1_ps (1.442695041f); /* log2(e) */
        const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
        const __m128 c0 =  _mm_set1_ps (0.238428936f);
        const __m128 c1 =  _mm_set1_ps (0.703448006f);
        const __m128 c2 =  _mm_set1_ps (1.000443142f);
    
        /* exp(x) = 2^i * 2^f; i = rint (log2(e) * x), -0.5 <= f <= 0.5 */
        t = _mm_mul_ps (x, l2e);             /* t = log2(e) * x */
        r = _mm_sub_ps (_mm_add_ps (t, cvt), cvt); /* r = rint (t) */
        f = _mm_sub_ps (t, r);               /* f = t - rint (t) */
        i = _mm_cvtps_epi32 (t);             /* i = rint (t), default round-to-nearest mode */
        p = c0;                              /* c0 */
        p = _mm_mul_ps (p, f);               /* c0 * f */
        p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
        p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
        p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= exp2(f) */
        j = _mm_slli_epi32 (i, 23);          /* i << 23 */
        r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
        return r;
    }
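    The rounding step can be checked in isolation. In this sketch (hypothetical helper rint_magic_ps, assuming |t| < 2^22 and the default round-to-nearest-even mode), ties such as 2.5 and 3.5 go to the even neighbor, and _mm_cvtps_epi32() rounds the same way, so the integer i and the float r in the code above agree. Value-changing floating-point optimizations (for example -ffast-math) may let the compiler fold the add/subtract pair away.

```c
#include <emmintrin.h>

/* Round each lane of t to the nearest integer (ties to even) with the
   1.5 * 2^23 "magic number" trick. Requires |t| < 2^22 and the default
   MXCSR rounding mode. */
static __m128 rint_magic_ps (__m128 t)
{
    const __m128 cvt = _mm_set1_ps (12582912.0f);  /* 1.5 * (1 << 23) */
    /* In [2^23, 2^24) the float spacing is 1.0, so the addition itself
       discards the fraction bits, rounding to nearest-even. */
    return _mm_sub_ps (_mm_add_ps (t, cvt), cvt);
}
```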
    

    The algorithm for the code in the question appears to be taken from the work of Nicol N. Schraudolph, which cleverly exploits the semi-logarithmic nature of IEEE-754 binary floating-point formats:

    N. N. Schraudolph. "A fast, compact approximation of the exponential function." Neural Computation, 11(4), May 1999, pp. 853-862.

    After removal of the argument clamping code, it reduces to just three SSE instructions. The "magical" correction constant 486411 is not optimal for minimizing the maximum relative error over the entire input domain. Based on a simple binary search, the value 298765 seems to be superior, reducing the maximum relative error of FastExpSse() to 3.56e-2, compared with 1.73e-3 for fast_exp_sse().

    /* max. rel. error = 3.55959567e-2 on [-87.33654, 88.72283] */
    __m128 FastExpSse (__m128 x)
    {
        __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
        __m128i b = _mm_set1_epi32 (127 * (1 << 23) - 298765);
        __m128i t = _mm_add_epi32 (_mm_cvtps_epi32 (_mm_mul_ps (a, x)), b);
        return _mm_castsi128_ps (t);
    }
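    A scalar sketch of the same trick may make it easier to see why it works (the function name schraudolph_exp is hypothetical). The integer round(x * 2^23 / log(2)) is x / log(2) in fixed point with 23 fraction bits; adding 127 << 23 applies the exponent bias, so the integer part i lands in the exponent field and the fractional part f fills the mantissa, which reads back as 2^i * (1 + f). Subtracting 298765 shifts that linear approximation down slightly to balance the error. _mm_cvtss_si32() is used here so that rounding matches _mm_cvtps_epi32().

```c
#include <stdint.h>
#include <string.h>
#include <xmmintrin.h>

/* Scalar sketch of Schraudolph's approximation with the 298765 correction.
   Assumes -87.33654 <= x <= 88.72283; no clamping is done. */
static float schraudolph_exp (float x)
{
    const float   a = 12102203.0f;               /* (1 << 23) / log(2) */
    const int32_t b = 127 * (1 << 23) - 298765;  /* biased exponent 127, minus correction */
    /* x / log(2) in 9.23 fixed point, rounded like _mm_cvtps_epi32 */
    int32_t t = _mm_cvtss_si32 (_mm_set_ss (a * x)) + b;
    float r;
    memcpy (&r, &t, sizeof r);                   /* reinterpret the bits as a float */
    return r;
}
```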
    

    Schraudolph's algorithm basically uses the linear approximation 2^f ~= 1.0 + f for f in [0,1], and its accuracy could be improved by adding a quadratic term. The clever part of Schraudolph's approach is computing 2^i * 2^f without explicitly separating the integer portion i = floor(x * 1.44269504) from the fraction. I see no way to extend that trick to a quadratic approximation, but one can certainly combine the floor() computation from Schraudolph with the quadratic approximation used above:

    /* max. rel. error <= 1.72886892e-3 on [-87.33654, 88.72283] */
    __m128 fast_exp_sse (__m128 x)
    {
        __m128 f, p, r;
        __m128i t, j;
        const __m128 a = _mm_set1_ps (12102203.0f); /* (1 << 23) / log(2) */
        const __m128i m = _mm_set1_epi32 (~0x007fffff); /* 0xff800000: mask for integer bits */
        const __m128 ttm23 = _mm_set1_ps (1.1920929e-7f); /* exp2(-23) */
        const __m128 c0 = _mm_set1_ps (0.3371894346f);
        const __m128 c1 = _mm_set1_ps (0.657636276f);
        const __m128 c2 = _mm_set1_ps (1.00172476f);
    
        t = _mm_cvtps_epi32 (_mm_mul_ps (a, x));
        j = _mm_and_si128 (t, m);            /* j = (int)(floor (x/log(2))) << 23 */
        t = _mm_sub_epi32 (t, j);
        f = _mm_mul_ps (ttm23, _mm_cvtepi32_ps (t)); /* f = (x/log(2)) - floor (x/log(2)) */
        p = c0;                              /* c0 */
        p = _mm_mul_ps (p, f);               /* c0 * f */
        p = _mm_add_ps (p, c1);              /* c0 * f + c1 */
        p = _mm_mul_ps (p, f);               /* (c0 * f + c1) * f */
        p = _mm_add_ps (p, c2);              /* p = (c0 * f + c1) * f + c2 ~= 2^f */
        r = _mm_castsi128_ps (_mm_add_epi32 (j, _mm_castps_si128 (p))); /* r = p * 2^i*/
        return r;
    }
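    The key step of this last variant is that clearing the low 23 bits of a two's-complement integer rounds it down to a multiple of 2^23, which is floor() for negative values as well as positive ones. A scalar sketch (the helper split_fixed is hypothetical) makes the split explicit: for x = -1, x / log(2) = -1.4427 yields i = -2 and f of about 0.5573.

```c
#include <stdint.h>
#include <xmmintrin.h>

/* Split s = x / log(2) into i = floor(s) and f = s - floor(s), 0 <= f < 1,
   using the fixed-point masking trick of the last fast_exp_sse() variant. */
static void split_fixed (float x, int32_t *i, float *f)
{
    const float a = 12102203.0f;                      /* (1 << 23) / log(2) */
    int32_t t = _mm_cvtss_si32 (_mm_set_ss (a * x));  /* s with 23 fraction bits */
    int32_t j = t & ~0x007fffff;                      /* floor(s) << 23, also for s < 0 */
    *i = j / (1 << 23);                               /* exact: j is a multiple of 2^23 */
    *f = (float)(t - j) * 1.1920929e-7f;              /* fraction bits times 2^-23 */
}
```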