c++optimizationintelsimdsse

Is it worth using SSE, or should I just rely on the compiler?


I am looking into SSE instructions, which are great, and started to write some simple code to measure the difference between a function using them and the same function using "standard" code (i.e. non-SSE). I realised that when I compile the code (with the -O3 flag), the version of the program that uses the SSE function is actually (very slightly) slower than the version that does NOT use SSE instructions. My guess is one of the following:

  1. the compiler does an excellent job at optimising the code
  2. the SSE function could run faster but there's a cost to loading the floats to the registers which cancels out the benefit from using the SSE instructions.
  3. the testSSE() function is not complex enough to really show a difference between a version of the program using SSE and one that doesn't.

Could anyone tell me what his/her thoughts are on this? Thanks a lot -

EDIT: so I corrected the code (see below the 2 code listings). Even with the corrected version which is shorter, the SSE version gives me 2''48 while the non-SSE version gives me 1''36, confirming the fact that, in that case the compiler does a better job than me!

EDIT: OLD CODE WITH BUG (see the corrected version below)

// compiled with c++ tmp.cpp -msse4 -o testSSE -O3

#include <iostream>
#include <cmath>

#include <stdio.h>
#include <pmmintrin.h>

// Sums the eight floats of each of the four nodes using SSE3 horizontal adds
// and stores the four totals in result: result[0] is the total of node1,
// result[1] of node2, result[2] of node3 and result[3] of node4.
//
// All memory accesses use _mm_load_ps / _mm_store_ps (the aligned forms), so
// node1..node4 and result are expected to be 16-byte aligned.
//
// The trailing commented-out statements are debug probes that dump an
// intermediate register through result and print it to stderr.
inline void testSSE(float *node1, float *node2, float *node3, float *node4, float *result)
{
    __m128 tmp0, tmp1, tmp2, tmp3;
    __m128 l, r;

    // node1: tmp0 = [n[0]+n[1], n[2]+n[3], n[4]+n[5], n[6]+n[7]]
    l = _mm_load_ps(node1);         //_mm_store_ps(result, l); fprintf(stderr, "1 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    r = _mm_load_ps(node1 + 4);     //_mm_store_ps(result, r); fprintf(stderr, "2 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    tmp0 = _mm_hadd_ps(l, r);       //_mm_store_ps(result, tmp0); fprintf(stderr, "3 %f %f %f %f\n", result[0], result[1], result[2], result[3]);

    // node2: same pairwise sums, kept in tmp1
    l = _mm_load_ps(node2);         //_mm_store_ps(result, l); fprintf(stderr, "4 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    r = _mm_load_ps(node2 + 4);     //_mm_store_ps(result, r); fprintf(stderr, "5 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    tmp1 = _mm_hadd_ps(l, r);       //_mm_store_ps(result, tmp0); fprintf(stderr, "6 %f %f %f %f\n", result[0], result[1], result[2], result[3]);

    // node3: pairwise sums in tmp2
    l = _mm_load_ps(node3);
    r = _mm_load_ps(node3 + 4);
    tmp2 = _mm_hadd_ps(l, r);

    // node4: pairwise sums in tmp3
    l = _mm_load_ps(node4);         //_mm_store_ps(result, l); fprintf(stderr, "10 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    r = _mm_load_ps(node4 + 4);     //_mm_store_ps(result, r); fprintf(stderr, "11 %f %f %f %f\n", result[0], result[1], result[2], result[3]);
    tmp3 = _mm_hadd_ps(l, r);       //_mm_store_ps(result, tmp0); fprintf(stderr, "12 %f %f %f %f\n", result[0], result[1], result[2], result[3]);

    // Second level: each register now holds the two half-sums (elements 0..3
    // and 4..7) of two nodes: l covers node1/node2, r covers node3/node4.
    l = _mm_hadd_ps(tmp0, tmp1);
    r = _mm_hadd_ps(tmp2, tmp3);

    // Third level: adding the two half-sums of each node gives the four totals.
    __m128 pDest = _mm_hadd_ps(l, r);

    _mm_store_ps(result, pDest);    // fprintf(stderr, "FINAL %f %f %f %f\n", result[0], result[1], result[2], result[3]);
}

// Adds v[0..7] as a balanced tree: neighbouring pairs first, then pairs of
// pairs, then the two halves. The association order matters for float rounding.
static float pairwise_sum8(const float *v)
{
    return ((v[0] + v[1]) + (v[2] + v[3])) + ((v[4] + v[5]) + (v[6] + v[7]));
}

// Scalar reference: writes the sum of the eight floats of node1..node4 into
// result[0..3]. Every node is read before anything is written to result.
void test(float *node1, float *node2, float *node3, float *node4, float *result)
{
    const float total1 = pairwise_sum8(node1);
    const float total2 = pairwise_sum8(node2);
    const float total3 = pairwise_sum8(node3);
    const float total4 = pairwise_sum8(node4);

    result[0] = total1;
    result[1] = total2;
    result[2] = total3;
    result[3] = total4;
}

// Benchmark driver of the old listing: calls testSSE() 10,000,000 times and
// prints the elapsed clock() time, in seconds, to stderr.
int main(int argc, char **argv)
{
    int nnodes = 4;
    double t = clock();
    for (int k = 0; k < 10000000; ++k) {
        // NOTE: the buffer is allocated, filled and released inside the timed
        // loop, so the measurement includes new[]/delete[] and the fill loop,
        // not only the call to testSSE().
        float *data = new float [nnodes * 8];
        // Integer division: floats 8*n .. 8*n+7 (node n+1) all get the value n+1.
        for (int i = 0; i < nnodes * 8; ++i) { data[i] = (i / 8) + 1; /* fprintf(stderr, "data %02d %f\n", i, data[i]); */ }
        // NOTE(review): testSSE() uses aligned loads and stores; confirm that
        // data and result are 16-byte aligned on the target platform.
        float result[4];
        // off is computed but never used.
        int off = sizeof(float) * 8;
        // The contents of result are never read after the call.
        testSSE(data, data + 8, data + 16, data + 24, result);
        delete [] data;
    }
    fprintf(stderr, "%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);
    return 0;
}

EDIT: new (corrected) code

#include <iostream>
#include <cmath>

#include <stdio.h>
#include <pmmintrin.h>

inline void testSSE(float *node1, float *node2, float *node3, float *node4, float *result)
{
    __m128 tmp0, tmp1, tmp2, tmp3;

    tmp0 = _mm_load_ps(node1);
    tmp1 = _mm_load_ps(node2);
    tmp2 = _mm_hadd_ps(tmp0, tmp1);

    tmp0 = _mm_load_ps(node3);
    tmp1 = _mm_load_ps(node4);
    tmp3 = _mm_hadd_ps(tmp0, tmp1);

    tmp0 = _mm_hadd_ps(tmp2, tmp3);

    _mm_store_ps(result, tmp0);
}

// Scalar reference implementation: result[n] receives the sum of the eight
// floats of the n-th node (node1..node4). Each node is reduced as a balanced
// tree (pairs, then quads, then halves), and all nodes are read before any
// element of result is written.
void test(float *node1, float *node2, float *node3, float *node4, float *result)
{
    float *nodes[4] = { node1, node2, node3, node4 };
    float totals[4];

    for (int n = 0; n < 4; ++n) {
        const float *v = nodes[n];
        const float lower = (v[0] + v[1]) + (v[2] + v[3]);
        const float upper = (v[4] + v[5]) + (v[6] + v[7]);
        totals[n] = lower + upper;
    }

    for (int n = 0; n < 4; ++n) {
        result[n] = totals[n];
    }
}

// Benchmark driver of the corrected listing: fills the node buffer once, then
// times repeated calls to the scalar test() and prints the elapsed clock()
// time, in seconds, to stderr.
int main(int argc, char **argv)
{

    int nnodes = 4;
    // Allocation and fill now happen once, outside the timed region.
    float *data = new float [nnodes * 8];
    // Integer division: floats 8*n .. 8*n+7 (node n+1) all get the value n+1.
    for (int i = 0; i < nnodes * 8; ++i) { data[i] = (i / 8) + 1; /* fprintf(stderr, "data %02d %f\n", i, data[i]); */ }
    double t = clock();
    // The loop bound is a double literal, so k is compared as a double on
    // every iteration.
    for (int k = 0; k < 1e+9; ++k) {
        float result[4];
        // off is computed but never used.
        int off = sizeof(float) * 8;
        // The contents of result are never read after the call.
        // NOTE(review): if this call is switched to testSSE(), confirm that data
        // and result are 16-byte aligned, since testSSE() uses aligned accesses.
        test(data, data + 8, data + 16, data + 24, result);
    }
    fprintf(stderr, "%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);
            delete [] data;
    return 0;
}

Solution

  • I fixed your code to use SIMD efficiently. Your old method takes 14.1 seconds on my computer and the new method takes 1.2 seconds. I rewrote the code in your test function to make it simpler to read, but otherwise it's the same.

    The old method stores the nodes in memory like this: node1[0], node1[1],...node1[7], node2[0], node2[1],.... The layout you have now is called an Array of Structs (AoS). That's the slow way to use SSE, and that's why it's no better than your scalar code.

    The new method, which uses SSE, stores the nodes like this: node1[0], node2[0], node3[0], node4[0], node1[1], node2[1], .... This is called a Struct of Arrays (SoA). That's the efficient way to use SIMD. In general, if you're using hadd often (or the dot product instruction), then you are probably not using the best algorithm for SIMD.

    Here is the code including your old method and my new one. Note, there are several additional ways you could try to make this more efficient, such as unrolling the loop, but now at least the SIMD is being used correctly.

    #include <iostream>
    #include <cmath>
    
    #include <stdio.h>
    #include <pmmintrin.h>
    
    // Adds v[0..7] strictly left to right, the same association order as the
    // expression v[0] + v[1] + ... + v[7].
    static float sum8_in_order(const float *v)
    {
        float total = v[0];
        for (int i = 1; i < 8; ++i) {
            total += v[i];
        }
        return total;
    }

    // Scalar reference for the array-of-structs layout: result[0..3] receive the
    // sums of the eight floats of node1..node4. Each total is stored before the
    // next node is read.
    void test(float *node1, float *node2, float *node3, float *node4, float *result)
    {
        result[0] = sum8_in_order(node1);
        result[1] = sum8_in_order(node2);
        result[2] = sum8_in_order(node3);
        result[3] = sum8_in_order(node4);
    }
    
    // Sums eight consecutive 4-float rows of a struct-of-arrays buffer.
    // Row i holds element i of node1..node4, so lane j of the running total
    // ends up as the sum of the eight elements of node j+1.
    // nodes_soa and result are accessed with the aligned _mm_load_ps /
    // _mm_store_ps and must therefore be 16-byte aligned.
    void testSSE(float *nodes_soa, float *result)
    {
        __m128 total = _mm_set1_ps(0.0f);
        const float *row = nodes_soa;
        const float *const past_last_row = nodes_soa + 4 * 8;

        while (row != past_last_row) {
            total = _mm_add_ps(_mm_load_ps(row), total);
            row += 4;
        }

        _mm_store_ps(result, total);
    }
    // Benchmark driver: times the scalar test() on an array-of-structs layout
    // and then testSSE() on a struct-of-arrays layout, printing the elapsed
    // clock() time in seconds for each. Returns 1 if the buffers cannot be
    // allocated, 0 otherwise.
    int main(int argc, char **argv)
    {

        int nnodes = 4;
        double t;

        // testSSE() uses _mm_load_ps / _mm_store_ps, which require 16-byte
        // aligned addresses. The language does not guarantee that alignment
        // for new float[] or for a plain float[4] local on every target, so
        // both buffers come from _mm_malloc (available through the intrinsics
        // header that is already included) and are released with _mm_free.
        float *data = (float *) _mm_malloc(nnodes * 8 * sizeof(float), 16);
        float *result = (float *) _mm_malloc(4 * sizeof(float), 16);
        if (!data || !result) {
            fprintf(stderr, "aligned allocation failed\n");
            _mm_free(data);
            _mm_free(result);
            return 1;
        }

        //old method using array of structs (AoS): node n occupies data[8n .. 8n+7]
        for (int i = 0; i < nnodes * 8; ++i) {
            data[i] = (i / 8) + 1;
            //printf("data %02d %f\n", i, data[i]);
        }

        t = clock();
        for (int k = 0; k < 1e+9; ++k) {
            test(data, data + 8, data + 16, data + 24, result);
            //printf("%f %f %f %f\n", result[0], result[1], result[2], result[3]);
        }
        printf("%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);

        //new method using struct of arrays (SoA): element i of every node sits
        //in data[4i .. 4i+3], so each aligned load fetches one element per node
        for (int i = 0; i < nnodes * 8; ++i) {
            data[i] = i%4 + 1;
            //printf("data %02d %f\n", i, data[i]);
        }

        t = clock();
        for (int k = 0; k < 1e+9; ++k) {
            testSSE(data, result);
            //printf("%f %f %f %f\n", result[0], result[1], result[2], result[3]);
        }
        printf("%02f (sec)\n", (clock() - t) / (float)CLOCKS_PER_SEC);

        _mm_free(result);
        _mm_free(data);
        return 0;
    }
    

    Edit: In general you want to use 16-byte alignment with SSE. Here are the functions I normally use.

    // Allocates size bytes whose address is a multiple of align.
    // Uses _aligned_malloc when compiled with MSVC and posix_memalign
    // otherwise. Returns 0 when posix_memalign reports an error; on the MSVC
    // path the return value of _aligned_malloc is passed through unchanged.
    // posix_memalign requires align to be a power of two and a multiple of
    // sizeof(void *).
    // Release the memory with aligned_free(), not with free() or delete[].
    inline void* aligned_malloc(size_t size, size_t align) {
        void *result;
        #ifdef _MSC_VER 
        result = _aligned_malloc(size, align);
        #else 
         // posix_memalign returns non-zero on failure; map that to a null result.
         if(posix_memalign(&result, align, size)) result = 0;
        #endif
        return result;
    }
    
    // Releases memory obtained from aligned_malloc(): _aligned_free when
    // compiled with MSVC, plain free otherwise (matching the allocator used
    // on each platform).
    inline void aligned_free(void *ptr) {
        #ifdef _MSC_VER 
            _aligned_free(ptr);
        #else 
          free(ptr);
        #endif
    
    }
    

    Use it like this, and release the buffer with aligned_free(data) instead of delete [] data:

    //float *data = new float [nnodes * 8];
    float *data = (float*) aligned_malloc(nnodes*8*sizeof(float), 16);