I use _mm256_fmadd_ps to perform a * b and accumulate the product into c, i.e. c = a * b + c. I have found that, under certain circumstances, the fmadd operation appears to lose precision compared with multiplying first and then adding separately, especially when c already holds a non-zero value.

Test code:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <math.h>
#include <time.h>
#include <intrin.h>
/*
 * Reference implementation: out[i] = in[i] * scalar + out[i] for i in [0, cnt).
 * Uses a separate float multiply and float add, i.e. two roundings per element.
 *
 * Cleanups vs. the original: the manual 8x unrolling is gone (compilers unroll
 * and vectorize this loop themselves at -O2, and the unrolled body performed
 * the exact same operations in the exact same order), and the const-discarding
 * casts of `in`/`out` are removed — they were unnecessary and hid intent.
 */
static inline void multiply_scalar_and_accumulate_generic(float *out, const float *in, const float scalar,
unsigned int cnt)
{
    const float *aPtr = in;
    float *cPtr = out;
    while (cnt-- > 0) {
        *cPtr = (*aPtr++) * scalar + (*cPtr);
        cPtr++;
    }
}
/*
 * AVX implementation of out[i] = in[i] * scalar + out[i], using a separate
 * vector multiply and vector add (two roundings per element, like the
 * generic version).
 *
 * BUG FIX: the original vector loop used `idx < cnt` while stepping idx by 8,
 * so for any cnt that is not a multiple of 8 the final iteration read and
 * wrote up to 7 floats past the end of both buffers (heap overflow / UB), and
 * the scalar tail loop below it could never execute. The condition is now
 * `idx + 8 <= cnt`, leaving the remainder to the scalar tail.
 */
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx")))   /* lets this TU build without a global -mavx */
#endif
static inline void multiply_scalar_and_accumulate_avx(float *out, const float *in, const float scalar, unsigned int cnt)
{
    unsigned int idx = 0;
    const float *aPtr = in;
    float *cPtr = out;
    const __m256 bVal = _mm256_set1_ps(scalar);
    for (; idx + 8 <= cnt; idx += 8)
    {
        __m256 aVal = _mm256_loadu_ps(aPtr);
        __m256 cVal = _mm256_loadu_ps(cPtr);
        cVal = _mm256_add_ps(cVal, _mm256_mul_ps(aVal, bVal));
        _mm256_storeu_ps(cPtr, cVal);
        aPtr += 8;
        cPtr += 8;
    }
    /* Scalar tail for the remaining cnt % 8 elements. */
    for (; idx < cnt; idx++)
    {
        *cPtr = (*aPtr++) * scalar + (*cPtr);
        cPtr++;
    }
}
/*
 * AVX+FMA implementation of out[i] = in[i] * scalar + out[i], using a fused
 * multiply-add: a single rounding per element, so it is at least as accurate
 * as the mul-then-add variants.
 *
 * BUG FIX: same out-of-bounds defect as the plain AVX version — the vector
 * loop condition `idx < cnt` with a step of 8 overran both buffers by up to
 * 7 floats whenever cnt was not a multiple of 8, and made the scalar tail
 * loop unreachable. Fixed to `idx + 8 <= cnt`.
 */
#if defined(__GNUC__) || defined(__clang__)
__attribute__((target("avx2,fma")))   /* lets this TU build without global -mfma */
#endif
static inline void multiply_scalar_and_accumulate_avx_fma(float *out, const float *in, const float scalar,
unsigned int cnt)
{
    unsigned int idx = 0;
    const float *aPtr = in;
    float *cPtr = out;
    const __m256 bVal = _mm256_set1_ps(scalar);
    for (; idx + 8 <= cnt; idx += 8)
    {
        __m256 aVal = _mm256_loadu_ps(aPtr);
        __m256 cVal = _mm256_loadu_ps(cPtr);
        cVal = _mm256_fmadd_ps(aVal, bVal, cVal);
        _mm256_storeu_ps(cPtr, cVal);
        aPtr += 8;
        cPtr += 8;
    }
    /* Scalar tail for the remaining cnt % 8 elements.  NOTE: this tail uses
     * separate * and +, so its rounding can differ from the fused path by
     * 1 ulp; acceptable for this accuracy-comparison test program. */
    for (; idx < cnt; idx++)
    {
        *cPtr = (*aPtr++) * scalar + (*cPtr);
        cPtr++;
    }
}
/*
 * Random positive scale factor in roughly (0, 10e?] range.  The original
 * computed rand()/rand()*10.0f, which divides by zero (producing inf/NaN and
 * poisoning every comparison) whenever the second rand() returns 0; OR-ing
 * the denominator with 1 keeps it non-zero while barely changing the
 * distribution.
 */
static float random_scale(void)
{
    return ((float)rand()) / ((float)(rand() | 1)) * 10.0f;
}

/*
 * Driver: fills an input buffer with random floats, runs the generic, AVX,
 * and AVX+FMA accumulate kernels over it (twice when MAKE_ACCUMULATE is set,
 * so the second pass accumulates onto non-zero outputs — the case where the
 * FMA vs mul+add rounding difference shows), and reports elements that
 * diverge from the generic reference beyond TOLERANCE.
 *
 * Fixes vs. the original: all four buffers are freed on every path (they
 * leaked), allocation failure returns EXIT_FAILURE instead of 0, the
 * divide-by-zero in the random generation is gone, and the malloc casts are
 * removed (unnecessary in C).
 */
int main(void)
{
#define TEST_COUNT (0x4000)
    int rc = EXIT_FAILURE; /* pessimistic default; set to success at the end */
    float *in = malloc(sizeof(float) * TEST_COUNT);
    float *ref = malloc(sizeof(float) * TEST_COUNT);
    float *out_avx = malloc(sizeof(float) * TEST_COUNT);
    float *out_avx_fma = malloc(sizeof(float) * TEST_COUNT);

    if ((in == NULL) || (ref == NULL) || (out_avx == NULL) || (out_avx_fma == NULL))
    {
        printf("alloc failed\n");
        goto cleanup; /* original leaked the successful allocations here */
    }
    printf("test start\n");

    float scalar = 0;
    const float TOLERANCE = 1e-3f;
    srand((unsigned)time(NULL));

    memset(ref, 0x0, sizeof(float) * TEST_COUNT);
    memset(out_avx, 0x0, sizeof(float) * TEST_COUNT);
    memset(out_avx_fma, 0x0, sizeof(float) * TEST_COUNT);

    for (int i = 0; i < TEST_COUNT; i++)
    {
        in[i] = random_scale();
    }
    scalar = random_scale();
    multiply_scalar_and_accumulate_generic(ref, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx(out_avx, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx_fma(out_avx_fma, in, scalar, TEST_COUNT);

#define MAKE_ACCUMULATE (1)
#if MAKE_ACCUMULATE
    /* Second pass accumulates onto already-non-zero outputs. */
    scalar = random_scale();
    multiply_scalar_and_accumulate_generic(ref, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx(out_avx, in, scalar, TEST_COUNT);
    multiply_scalar_and_accumulate_avx_fma(out_avx_fma, in, scalar, TEST_COUNT);
#endif

    for (int i = 0; i < TEST_COUNT; i++)
    {
        float diff_avx = fabsf(out_avx[i] - ref[i]);
        float diff_avx_fma = fabsf(out_avx_fma[i] - ref[i]);
        if (diff_avx > TOLERANCE)
        {
            printf("[Err AVX] pos:%06d, %20.4f != %20.4f, avx_diff:%.4f, avx_fma_diff:%.4f\n", i, out_avx[i], ref[i],
                   diff_avx, diff_avx_fma);
        }
        if (diff_avx_fma > TOLERANCE)
        {
            printf("[Err AVX_FMA] pos:%06d, %20.4f != %20.4f, avx_fma_diff:%.4f, avx_diff:%.4f\n", i, out_avx_fma[i], ref[i],
                   diff_avx_fma, diff_avx);
        }
    }
    printf("test end\n");
    rc = EXIT_SUCCESS;

cleanup:
    /* free(NULL) is a no-op, so this is safe even after partial allocation. */
    free(in);
    free(ref);
    free(out_avx);
    free(out_avx_fma);
    return rc;
}
result:
Your code does not show any "precision loss"¹ with the FMA. The FMA provides better results, and you are comparing them to poorer results produced by separate multiplication and addition.
Your code compares the results computed by multiply_scalar_and_accumulate_avx and by multiply_scalar_and_accumulate_avx_fma with reference results computed using plain C `*` and `+`, thus using separate multiplication and addition. In these reference results, the `*` and the `+` are each computed (in your C implementation) separately in float precision, and each operation has built-in rounding to float precision: the result of `*` is the equivalent of the exact real-number product rounded to float precision, and the result of `+` is the equivalent of the exact real-number sum rounded to float precision.

In contrast, the _mm256_fmadd_ps instruction performs a single rounding. The result it provides for operands a, b, and c is the equivalent of the exact real-number expression a·b+c rounded to float precision. This result is always the representable float value closest to a·b+c, so it is always the best result possible.

It is your reference results that are inaccurate. Because they involve two roundings, they sometimes differ from the better result provided by the FMA.
If you change multiply_scalar_and_accumulate_generic to use double arithmetic (which you can do simply by changing its parameter `const float scalar` to `const double scalar`), or change it to use the `fma` function from the standard C library, your program will instead report that the AVX results differ from the reference values, rather than the AVX_FMA results.

¹ This is a misnomer. Precision is the fineness with which numbers are represented. All numbers in this code use IEEE-754 binary32, which has 24-bit significands, so they all have the same precision. Your concern is actually about accuracy, which is how well a number represents some ideal target value.