c++ simd avx2 dot-product fma

AVX2: Computing the dot product of two 512-element float arrays


I will preface this by saying that I am a complete beginner at SIMD intrinsics.

Essentially, I have a CPU that supports the AVX2 instruction set (Intel(R) Core(TM) i5-7500T CPU @ 2.70GHz). I would like to know the fastest way to compute the dot product of two std::vector<float> of size 512.

I have done some digging online and found this and this, and this Stack Overflow question suggests using the function __m256 _mm256_dp_ps(__m256 m1, __m256 m2, const int mask);. However, these all suggest different ways of performing the dot product, and I am not sure which is the correct (and fastest) way to do it.

In particular, I am looking for the fastest way to perform a dot product for vectors of size 512 (because I know the vector size affects the implementation).

Thank you for your help

Edit 1: I am also a little confused about the -mavx2 gcc flag. If I use these AVX2 functions, do I need to add the flag when I compile? Also, is gcc able to do these optimizations for me (say, if I use the -Ofast gcc flag) if I write a naive dot product implementation?

Edit 2: If anyone has the time and energy, I would very much appreciate it if you could write a full implementation. I am sure other beginners would also value this information.


Solution

  • _mm256_dp_ps is only useful for dot-products of 2 to 4 elements; for longer vectors use vertical SIMD in a loop and reduce to scalar at the end. Using _mm256_dp_ps and _mm256_add_ps in a loop would be much slower.


    GCC and clang require you to enable (with command line options) ISA extensions that you use intrinsics for, unlike MSVC and ICC.


    The code below is probably close to the theoretical performance limit of your CPU. Untested.

    Compile it with clang or gcc -O3 -march=native. (It requires at least -mavx -mfma, but the -mtune options implied by -march are good too, and so are -mpopcnt and the other extensions that -march=native enables. Tune options are critical for this to compile efficiently on most CPUs with FMA, specifically -mno-avx256-split-unaligned-load; see "Why doesn't gcc resolve _mm256_loadu_pd as single vmovupd?")

    Or compile it with MSVC -O2 -arch:AVX2
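Because GCC and clang only accept the intrinsics when the matching extensions are enabled, it can help to check what the compiler actually turned on for a given set of flags. The sketch below relies on the predefined macros `__AVX__`, `__AVX2__` and `__FMA__`, which GCC and clang define when the corresponding extension is enabled. The helper name `enabledSimdFeatures` is made up for this illustration, and as far as I know MSVC defines `__AVX2__` under -arch:AVX2 but not `__FMA__`, so treat the FMA line as GCC/clang-specific.

```cpp
#include <string>

// Hypothetical helper (not part of the answer's code): lists the SIMD
// feature macros the compiler defined for this translation unit.
inline std::string enabledSimdFeatures()
{
    std::string s;
#if defined( __AVX__ )
    s += "AVX ";
#endif
#if defined( __AVX2__ )
    s += "AVX2 ";
#endif
#if defined( __FMA__ )
    // Assumption: defined by GCC and clang with -mfma; MSVC may not define it.
    s += "FMA ";
#endif
    if( s.empty() )
        return "none";
    return s;
}
```

Printing the returned string from a small test program built with and without -march=native shows whether the flags took effect. This only reports what the compiler was allowed to emit, not what the CPU running the binary supports.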

    #include <immintrin.h>
    #include <vector>
    #include <assert.h>
    
    // CPUs support RAM access like this: "ymmword ptr [rax+64]"
    // Using templates with an int offset argument to make it easier for the compiler to emit good code.
    
    // Multiply 8 floats by another 8 floats.
    template<int offsetRegs>
    inline __m256 mul8( const float* p1, const float* p2 )
    {
        constexpr int lanes = offsetRegs * 8;
        const __m256 a = _mm256_loadu_ps( p1 + lanes );
        const __m256 b = _mm256_loadu_ps( p2 + lanes );
        return _mm256_mul_ps( a, b );
    }
    
    // Returns acc + ( p1 * p2 ), for 8-wide float lanes.
    template<int offsetRegs>
    inline __m256 fma8( __m256 acc, const float* p1, const float* p2 )
    {
        constexpr int lanes = offsetRegs * 8;
        const __m256 a = _mm256_loadu_ps( p1 + lanes );
        const __m256 b = _mm256_loadu_ps( p2 + lanes );
        return _mm256_fmadd_ps( a, b, acc );
    }
    
    // Compute dot product of float vectors, using 8-wide FMA instructions.
    float dotProductFma( const std::vector<float>& a, const std::vector<float>& b )
    {
        assert( a.size() == b.size() );
        assert( 0 == ( a.size() % 32 ) );
        if( a.empty() )
            return 0.0f;
    
        const float* p1 = a.data();
        const float* const p1End = p1 + a.size();
        const float* p2 = b.data();
    
        // Process initial 32 values. Nothing to add yet, just multiplying.
        __m256 dot0 = mul8<0>( p1, p2 );
        __m256 dot1 = mul8<1>( p1, p2 );
        __m256 dot2 = mul8<2>( p1, p2 );
        __m256 dot3 = mul8<3>( p1, p2 );
        p1 += 8 * 4;
        p2 += 8 * 4;
    
        // Process the rest of the data.
        // The code uses FMA instructions to multiply + accumulate, consuming 32 values per loop iteration.
        // Unrolling manually for 2 reasons:
        // 1. To reduce data dependencies. With a single register, every loop iteration would depend on the previous result.
        // 2. Unrolled code checks for exit condition 4x less often, therefore more CPU cycles spent computing useful stuff.
        while( p1 < p1End )
        {
            dot0 = fma8<0>( dot0, p1, p2 );
            dot1 = fma8<1>( dot1, p1, p2 );
            dot2 = fma8<2>( dot2, p1, p2 );
            dot3 = fma8<3>( dot3, p1, p2 );
            p1 += 8 * 4;
            p2 += 8 * 4;
        }
    
        // Add 32 values into 8
        const __m256 dot01 = _mm256_add_ps( dot0, dot1 );
        const __m256 dot23 = _mm256_add_ps( dot2, dot3 );
        const __m256 dot0123 = _mm256_add_ps( dot01, dot23 );
        // Add 8 values into 4
        const __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( dot0123 ), _mm256_extractf128_ps( dot0123, 1 ) );
        // Add 4 values into 2
        const __m128 r2 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
        // Add 2 lower values into the final result
        const __m128 r1 = _mm_add_ss( r2, _mm_movehdup_ps( r2 ) );
        // Return the lowest lane of the result vector.
        // The intrinsic below compiles into noop, modern compilers return floats in the lowest lane of xmm0 register.
        return _mm_cvtss_f32( r1 );
    }
    

    Possible further improvements:

    1. Unroll by 8 vectors instead of 4. I’ve checked gcc 9.2 asm output, compiler only used 8 vector registers out of the 16 available.

    2. Make sure both input vectors are aligned, e.g. use a custom allocator which calls _aligned_malloc / _aligned_free on msvc, or aligned_alloc / free on gcc & clang. Then replace _mm256_loadu_ps with _mm256_load_ps.
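The aligned-allocator suggestion can be sketched in portable C++17 with the aligned forms of `operator new` and `operator delete`, instead of the platform-specific `_aligned_malloc` / `aligned_alloc` calls named above. The names `AlignedAllocator` and `AlignedFloatVector` are hypothetical, and this is a minimal illustration rather than a production allocator.

```cpp
#include <cstddef>
#include <cstdint>
#include <new>
#include <vector>

// Hypothetical minimal allocator: every allocation is aligned to `Alignment` bytes.
// Assumes C++17 (aligned operator new / delete).
template<class T, std::size_t Alignment = 32>
struct AlignedAllocator
{
    using value_type = T;
    static_assert( ( Alignment & ( Alignment - 1 ) ) == 0, "Alignment must be a power of two" );
    static_assert( Alignment >= alignof( T ), "Alignment is too small for T" );

    AlignedAllocator() noexcept = default;
    template<class U>
    AlignedAllocator( const AlignedAllocator<U, Alignment>& ) noexcept {}

    // Explicit rebind is required because of the non-type template parameter.
    template<class U>
    struct rebind { using other = AlignedAllocator<U, Alignment>; };

    T* allocate( std::size_t n )
    {
        if( n > SIZE_MAX / sizeof( T ) )
            throw std::bad_array_new_length();
        void* p = ::operator new( n * sizeof( T ), std::align_val_t( Alignment ) );
        return static_cast<T*>( p );
    }
    void deallocate( T* p, std::size_t ) noexcept
    {
        ::operator delete( p, std::align_val_t( Alignment ) );
    }
};

// The allocator is stateless, so all instances compare equal.
template<class T, class U, std::size_t A>
bool operator==( const AlignedAllocator<T, A>&, const AlignedAllocator<U, A>& ) noexcept { return true; }
template<class T, class U, std::size_t A>
bool operator!=( const AlignedAllocator<T, A>&, const AlignedAllocator<U, A>& ) noexcept { return false; }

// 32 bytes matches the width of one __m256 register.
using AlignedFloatVector = std::vector<float, AlignedAllocator<float, 32>>;
```

Note that `AlignedFloatVector` is a different type from `std::vector<float>`, so `dotProductFma` would need its parameter types changed (or made a template) before `_mm256_load_ps` could safely replace `_mm256_loadu_ps`.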


    To auto-vectorize a simple scalar dot product, you'd also need OpenMP SIMD or -ffast-math (implied by -Ofast) to let the compiler treat FP math as associative even though it's not (because of rounding). But GCC won't use multiple accumulators when auto-vectorizing, even if it does unroll, so you'd bottleneck on FMA latency, not load throughput.

    (2 loads per FMA means the throughput bottleneck for this code is vector loads, not actual FMA operations.)
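For comparison, here is roughly what the naive scalar version with an OpenMP SIMD hint might look like. The function name `dotProductScalar` is made up for this sketch. The pragma only has an effect when compiled with -fopenmp-simd (or -fopenmp) on GCC/clang; without that the compiler ignores it, possibly with a warning, and the loop stays a plain scalar reduction unless -ffast-math is used.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Naive dot product, written so the compiler is allowed to vectorize the reduction.
// The reduction clause tells the compiler that reordering the additions is acceptable.
inline float dotProductScalar( const std::vector<float>& a, const std::vector<float>& b )
{
    assert( a.size() == b.size() );
    const float* pa = a.data();
    const float* pb = b.data();
    const std::size_t n = a.size();

    float sum = 0.0f;
    #pragma omp simd reduction( + : sum )
    for( std::size_t i = 0; i < n; i++ )
        sum += pa[ i ] * pb[ i ];
    return sum;
}
```

Even when this vectorizes, the point above still applies: a single vector accumulator makes each FMA wait on the previous one, which is what the four manual accumulators in the intrinsics version avoid.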


    Update 2023: because this answer collected many upvotes, here’s another version which supports vectors of arbitrary lengths, not necessarily a multiple of 32 elements. The main loop is the same, the difference is handling of the remainder.

    As you see, it’s relatively tricky to handle the remainder in a way which is both performant, and fair with regards to summation order. Summation order affects numerical precision of the result. The key part of the implementation is _mm256_maskload_ps conditional load instruction.

    #include <immintrin.h>
    #include <vector>
    #include <algorithm>
    #include <assert.h>
    #include <stdint.h>
    #include <stddef.h>
    
    // CPUs support RAM access like this: "ymmword ptr [rax+64]"
    // Using templates with an int offset argument to make it easier for the compiler to emit good code.
    
    // Returns acc + ( p1 * p2 ), for 8 float lanes
    template<int offsetRegs>
    inline __m256 fma8( __m256 acc, const float* p1, const float* p2 )
    {
        constexpr ptrdiff_t lanes = offsetRegs * 8;
        const __m256 a = _mm256_loadu_ps( p1 + lanes );
        const __m256 b = _mm256_loadu_ps( p2 + lanes );
        return _mm256_fmadd_ps( a, b, acc );
    }
    
    #ifdef __AVX2__
    inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
    {
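        // Make a mask of 8 bytes
        // These aren't branches, they should compile to conditional moves
        missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
        // All ones when at least one lane is present, zero when all 8 lanes are missing
        uint64_t mask = -( missingLanes < 8 );
        // Clamp the shift count: shifting a 64-bit integer by 64 or more bits is undefined behavior.
        // When missingLanes >= 8 the mask is already zero, so the clamped shift doesn't change the result.
        mask >>= std::min( missingLanes, (ptrdiff_t)7 ) * 8;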
        // Sign extend the bytes into int32 lanes in AVX vector
        __m128i tmp = _mm_cvtsi64_si128( (int64_t)mask );
        return _mm256_cvtepi8_epi32( tmp );
    }
    #else
    // Aligned by 64 bytes
    // The load will only touch a single cache line, no penalty for unaligned load
    alignas( 64 ) static const int s_remainderLoadMask[ 16 ] = {
        -1, -1, -1, -1, -1, -1, -1, -1, 0, 0, 0, 0, 0, 0, 0, 0 };
    inline __m256i makeRemainderMask( ptrdiff_t missingLanes )
    {
        // These aren't branches, they compile to conditional moves
        missingLanes = std::max( missingLanes, (ptrdiff_t)0 );
        missingLanes = std::min( missingLanes, (ptrdiff_t)8 );
        // Unaligned load from a constant array
        const int* rsi = &s_remainderLoadMask[ missingLanes ];
        return _mm256_loadu_si256( ( const __m256i* )rsi );
    }
    #endif
    
    // Same as fma8(), load conditionally using the mask
    // When the mask has all bits set, an equivalent of fma8(), but 1 instruction longer
    // When the mask is a zero vector, the function won't load anything, will return `acc`
    template<int offsetRegs>
    inline __m256 fma8rem( __m256 acc, const float* p1, const float* p2, ptrdiff_t rem )
    {
        constexpr ptrdiff_t lanes = offsetRegs * 8;
        // Generate the mask for conditional loads
        // The implementation depends on whether AVX2 is enabled with compiler switches
        const __m256i mask = makeRemainderMask( ( 8 + lanes ) - rem );
        // These conditional load instructions produce zeros for the masked out lanes
        const __m256 a = _mm256_maskload_ps( p1 + lanes, mask );
        const __m256 b = _mm256_maskload_ps( p2 + lanes, mask );
        return _mm256_fmadd_ps( a, b, acc );
    }
    
    // Compute dot product of float vectors, using 8-wide FMA instructions
    float dotProductFma( const std::vector<float>& a, const std::vector<float>& b )
    {
        assert( a.size() == b.size() );
        const size_t length = a.size();
        if( length == 0 )
            return 0.0f;
    
        const float* p1 = a.data();
        const float* p2 = b.data();
        // Compute length of the remainder.
        // We want a remainder of length [ 1 .. 32 ] instead of [ 0 .. 31 ]
        const ptrdiff_t rem = ( ( length - 1 ) % 32 ) + 1;
        const float* const p1End = p1 + length - rem;
    
        // Initialize accumulators with zeros
        __m256 dot0 = _mm256_setzero_ps();
        __m256 dot1 = _mm256_setzero_ps();
        __m256 dot2 = _mm256_setzero_ps();
        __m256 dot3 = _mm256_setzero_ps();
    
        // Process the majority of the data.
        // The code uses FMA instructions to multiply + accumulate, consuming 32 values per loop iteration.
        // Unrolling manually for 2 reasons:
        // 1. To reduce data dependencies. With a single register, every loop iteration would depend on the previous result.
        // 2. Unrolled code checks for exit condition 4x less often, therefore more CPU cycles spent computing useful stuff.
        while( p1 < p1End )
        {
            dot0 = fma8<0>( dot0, p1, p2 );
            dot1 = fma8<1>( dot1, p1, p2 );
            dot2 = fma8<2>( dot2, p1, p2 );
            dot3 = fma8<3>( dot3, p1, p2 );
            p1 += 32;
            p2 += 32;
        }
    
        // Handle the last, possibly incomplete batch of length [ 1 .. 32 ]
        // To save multiple branches, we load that entire batch with `vmaskmovps` conditional loads
        // On modern CPUs, the performance of such loads is pretty close to normal full vector loads
        dot0 = fma8rem<0>( dot0, p1, p2, rem );
        dot1 = fma8rem<1>( dot1, p1, p2, rem );
        dot2 = fma8rem<2>( dot2, p1, p2, rem );
        dot3 = fma8rem<3>( dot3, p1, p2, rem );
    
        // Add 32 values into 8
        dot0 = _mm256_add_ps( dot0, dot2 );
        dot1 = _mm256_add_ps( dot1, dot3 );
        dot0 = _mm256_add_ps( dot0, dot1 );
        // Add 8 values into 4
        __m128 r4 = _mm_add_ps( _mm256_castps256_ps128( dot0 ),
            _mm256_extractf128_ps( dot0, 1 ) );
        // Add 4 values into 2
        r4 = _mm_add_ps( r4, _mm_movehl_ps( r4, r4 ) );
        // Add 2 lower values into the scalar result
        r4 = _mm_add_ss( r4, _mm_movehdup_ps( r4 ) );
    
        // Return the lowest lane of the result vector.
        // The intrinsic below compiles into noop, modern compilers return floats in the lowest lane of xmm0 register.
        return _mm_cvtss_f32( r4 );
    }
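Since summation order affects the rounding of the result, an exact comparison against a scalar loop is not a good way to test either version above. A more forgiving check, sketched here with made-up helper names (`dotProductReference`, `matchesReference`) and an arbitrarily chosen tolerance, is to accumulate a reference in double precision and compare within a relative error.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Reference dot product accumulated in double precision.
inline double dotProductReference( const std::vector<float>& a, const std::vector<float>& b )
{
    assert( a.size() == b.size() );
    double sum = 0.0;
    for( std::size_t i = 0; i < a.size(); i++ )
        sum += static_cast<double>( a[ i ] ) * static_cast<double>( b[ i ] );
    return sum;
}

// Hypothetical helper: true when `result` agrees with the reference.
// The error bound is scaled by the sum of |a[i] * b[i]|, so the check
// stays meaningful when positive and negative terms cancel.
// The default tolerance is an assumption, not a derived bound.
inline bool matchesReference( float result, const std::vector<float>& a,
    const std::vector<float>& b, double relTol = 1e-4 )
{
    const double ref = dotProductReference( a, b );
    double scale = 0.0;
    for( std::size_t i = 0; i < a.size(); i++ )
        scale += std::fabs( static_cast<double>( a[ i ] ) * static_cast<double>( b[ i ] ) );
    return std::fabs( static_cast<double>( result ) - ref ) <= relTol * scale;
}
```

A call such as `matchesReference( dotProductFma( a, b ), a, b )` for a few lengths around multiples of 8 and 32 (for example 1, 7, 8, 31, 32, 33, 512) would exercise both the main loop and the masked remainder path.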