Tags: c, precision, avx, avx2, fma

Why does '_mm256_fmadd_ps' cause precision loss?


I use _mm256_fmadd_ps to perform a * b and accumulate it to the result c, like c=a*b+c. It is found that under certain circumstances, fmadd operations will cause precision loss compared with those that mul first and then add, especially when c already has a non-zero value.

test code:


#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <intrin.h>

/**
 * Scalar reference implementation: out[i] += in[i] * scalar for cnt elements.
 *
 * Each element is computed with a separate multiply and add, so the result
 * is rounded to float twice per element (once after *, once after +) —
 * this is the behavior the AVX mul+add path is expected to match.
 *
 * The original hand-unrolled this loop 8x; that obscures the logic and the
 * compiler performs the same unrolling at -O2 anyway.  The (float *) casts
 * that stripped const from `in` were also unnecessary and are removed.
 */
static inline void multiply_scalar_and_accumulate_generic(float *out, const float *in, const float scalar,
                                                          unsigned int cnt)
{
    for (unsigned int i = 0; i < cnt; i++) {
        out[i] = in[i] * scalar + out[i];
    }
}

static inline void multiply_scalar_and_accumulate_avx(float *out, const float *in, const float scalar, unsigned int cnt)
{
    unsigned int idx = 0;

    const float *aPtr = (float *)in;
    float       *cPtr = (float *)out;

    __m256       aVal;
    __m256       cVal;
    const __m256 bVal = _mm256_set1_ps(scalar);

    for (; idx < cnt; idx += 8)
    {
        aVal = _mm256_loadu_ps(aPtr);
        cVal = _mm256_loadu_ps(cPtr);

        cVal = _mm256_add_ps(cVal, _mm256_mul_ps(aVal, bVal));

        _mm256_storeu_ps(cPtr, cVal);

        aPtr += 8;
        cPtr += 8;
    }

    for (; idx < cnt; idx++)
    {
        *cPtr = (*aPtr++) * scalar + (*cPtr);
        cPtr++;
    }
    return;
}

static inline void multiply_scalar_and_accumulate_avx_fma(float *out, const float *in, const float scalar,
                                                          unsigned int cnt)
{
    unsigned int idx = 0;

    const float *aPtr = (float *)in;
    float       *cPtr = (float *)out;

    __m256       aVal;
    __m256       cVal;
    const __m256 bVal = _mm256_set1_ps(scalar);

    for (; idx < cnt; idx += 8)
    {
        aVal = _mm256_loadu_ps(aPtr);
        cVal = _mm256_loadu_ps(cPtr);

        cVal = _mm256_fmadd_ps(aVal, bVal, cVal);

        _mm256_storeu_ps(cPtr, cVal);

        aPtr += 8;
        cPtr += 8;
    }

    for (; idx < cnt; idx++)
    {
        *cPtr = (*aPtr++) * scalar + (*cPtr);
        cPtr++;
    }
    return;
}

/* Random positive scale factor in roughly (0, 10] * (rand ratio).
 * The +1.0f in the denominator avoids division by zero when rand()
 * returns 0, which in the original could inject inf/NaN into the data. */
static float random_scale(void)
{
    return ((float)rand()) / ((float)rand() + 1.0f) * 10.0f;
}

/**
 * Test driver: fills an input buffer with random values, runs the scalar
 * reference, the AVX mul+add path, and the AVX FMA path over it (twice when
 * MAKE_ACCUMULATE is set, so the accumulator is non-zero on the second pass),
 * then reports elements where either vector path differs from the reference
 * by more than TOLERANCE.
 *
 * Fixes vs. the original: malloc results are no longer cast and use
 * `sizeof *ptr`; all buffers are freed on every exit path (the original
 * leaked them, including on the alloc-failure path); alloc failure now
 * returns a non-zero status; srand uses time(NULL) instead of time(0).
 */
int main(void)
{
#define TEST_COUNT (0x4000)
    int    rc          = 1; /* pessimistic: flipped to 0 on success */
    float *in          = malloc(sizeof *in * TEST_COUNT);
    float *ref         = malloc(sizeof *ref * TEST_COUNT);
    float *out_avx     = malloc(sizeof *out_avx * TEST_COUNT);
    float *out_avx_fma = malloc(sizeof *out_avx_fma * TEST_COUNT);

    if ((in == NULL) || (ref == NULL) || (out_avx == NULL) || (out_avx_fma == NULL))
    {
        printf("alloc failed\n");
        goto cleanup;
    }

    printf("test start\n");

    const float TOLERANCE = 1e-3f;

    srand(time(NULL));

    /* Accumulators start at zero; inputs are random positives. */
    memset(ref, 0x0, sizeof *ref * TEST_COUNT);
    memset(out_avx, 0x0, sizeof *out_avx * TEST_COUNT);
    memset(out_avx_fma, 0x0, sizeof *out_avx_fma * TEST_COUNT);

    for (int i = 0; i < TEST_COUNT; i++)
    {
        in[i] = random_scale();
    }

    float scalar = random_scale();
    multiply_scalar_and_accumulate_generic(ref, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx(out_avx, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx_fma(out_avx_fma, in, scalar, TEST_COUNT);

#define MAKE_ACCUMULATE (1)

#if MAKE_ACCUMULATE
    /* Second pass: the accumulators are now non-zero, which is the case
     * where FMA and mul+add visibly diverge. */
    scalar = random_scale();
    multiply_scalar_and_accumulate_generic(ref, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx(out_avx, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx_fma(out_avx_fma, in, scalar, TEST_COUNT);
#endif

    for (int i = 0; i < TEST_COUNT; i++)
    {
        float diff_avx     = fabsf(out_avx[i] - ref[i]);
        float diff_avx_fma = fabsf(out_avx_fma[i] - ref[i]);
        if (diff_avx > TOLERANCE)
        {
            printf("[Err AVX] pos:%06d, %20.4f != %20.4f, avx_diff:%.4f, avx_fma_diff:%.4f\n", i, out_avx[i], ref[i],
                   diff_avx, diff_avx_fma);
        }
        if (diff_avx_fma > TOLERANCE)
        {
            printf("[Err AVX_FMA] pos:%06d, %20.4f != %20.4f, avx_fma_diff:%.4f, avx_diff:%.4f\n", i, out_avx_fma[i], ref[i],
                   diff_avx_fma, diff_avx);
        }
    }

    printf("test end\n");
    rc = 0;

cleanup:
    /* free(NULL) is a no-op, so this is safe on the alloc-failure path. */
    free(in);
    free(ref);
    free(out_avx);
    free(out_avx_fma);
    return rc;
}

result (screenshots omitted in this extract):

With MAKE_ACCUMULATE == 0, all three implementations agree within the tolerance and nothing is printed.

With MAKE_ACCUMULATE == 1 (the accumulator already holds non-zero values), a series of [Err AVX_FMA] lines is printed — the FMA results differ from the mul-then-add reference.


Solution

  • Your code does not show any "precision loss"[1] with the FMA. The FMA provides better results, and you are comparing them to poorer results provided by separate multiplication and addition.

    Your code compares the results computed by multiply_scalar_and_accumulate_avx and by multiply_scalar_and_accumulate_avx_fma with reference results computed using plain C * and +, thus using separate addition and multiplication. In these reference results, the * and the + are each computed (in your C implementation) separately using float precision. These operations have built-in rounding to float precision: The result of * is the equivalent of the exact real-number arithmetic product rounded to float precision, and the result of + is the equivalent of the exact real-number arithmetic sum rounded to float precision.

    In contrast, the _mm256_fmadd_ps instruction has a single rounding. The result it provides for operands a, b, and c is the equivalent of the exact real-number arithmetic expression ab+c rounded to float precision. This result is always the value representable in float that is closest to ab+c, so it is always the best result possible.

    It is your reference results that are inaccurate. Because they have two roundings, their results sometimes differ from the better result provided by FMA.

    If you change multiply_scalar_and_accumulate_generic to use double arithmetic (which you can do simply by changing its parameter const float scalar to const double scalar) or you change it to call the standard C library's fused multiply-add function fmaf (from <math.h>), your program will report that the AVX results differ from the reference values, rather than the AVX_FMA results.

    Footnote

    [1] This is a misnomer. Precision is the fineness with which numbers are represented. All numbers in this code use IEEE-754 binary32, which has 24-bit significands, so they all have the same precision. Your concern is actually about accuracy, which is how well a number represents some ideal target value.