c++ inheritance types expression-templates

C++ Expression Templates - Why the base class?


I recently stumbled across expression templates in C++. There is one thing I do not quite understand about their implementation: why is the base class necessary (the one from which all other expression-template types derive via CRTP)? Here is a simple example that does addition and scalar multiplication on vectors (objects of type Vec), without a base class:

#include <vector>
#include <iostream>
#include <cstdlib>   // for system()
using namespace std;

class Vec
{
    vector<double> data;
public:

    template<typename E>
    Vec(E expr)
    {
        data = vector<double>(expr.size());
        for (int i = 0; i < expr.size(); i++)
            data[i] = expr[i];
    }
    Vec(int size)
    {
        data = vector<double>(size);
        for (int i = 0; i < size; i++)
            data[i] = i;
    }
    double operator [] (int idx) {
        return data[idx];
    }

    int size() { return data.size(); }

    bool operator < (Vec &rhs)
    {
        return (*this)[0] < rhs[0];
    }

    bool operator > (Vec &rhs)
    {
        return (*this)[0] > rhs[0];
    }

};

template<typename E1, typename E2>
class VecAdd
{
    E1 vec_expr1;
    E2 vec_expr2;

public:
    VecAdd(E1 vec_expr1, E2 vec_expr2) : vec_expr1(vec_expr1), vec_expr2(vec_expr2)
    {}

    double operator [] (int idx) { return vec_expr1[idx] + vec_expr2[idx]; }
    int size() { return vec_expr1.size(); }
};

template<typename E>
class ScalarMult
{
    E vec_expr;
    double scalar;

public:
    // Initializers listed in declaration order (vec_expr, then scalar).
    ScalarMult(double scalar, E vec_expr) : vec_expr(vec_expr), scalar(scalar)
    {}

    double operator [] (int idx) { return scalar*vec_expr[idx]; }
    int size() { return vec_expr.size(); }
};

template<typename E1, typename E2>
VecAdd<E1, E2> operator + (E1 vec_expr1, E2 vec_expr2)
{
    return VecAdd<E1, E2>(vec_expr1, vec_expr2);
}

template<typename E>
ScalarMult<E> operator * (double scalar, E vec_expr)
{
    return ScalarMult<E>(scalar, vec_expr);
}

int main()
{
    Vec term1(5);
    Vec term2(5);

    Vec result = 6*(term1 + term2);
    Vec result2 = 4 * (term1 + term2 + term1);

    //vector<Vec> vec_vector = {result, result2};     does not compile
    vector<Vec> vec_vector;

    vec_vector = { result2, result };   //compiles

    vec_vector.clear();
    vec_vector.push_back(result);
    vec_vector.push_back(result2);      //all this compiles

    for (int i = 0; i < result.size(); i++)
        cout << result[i] << " ";
    cout << endl;

    system("pause");

    return 0;
}

The code above compiles (except for the indicated line) and evaluates the simple expressions in main correctly. If every expression ends up being assigned to a Vec, which copies its contents, and the expression object is destroyed afterwards anyway, why is a base class necessary (as shown in this Wikipedia article)?

EDIT:

I know this code is a bit messy and bad (copying where unnecessary, etc.) but I am not planning on using this specific code. This is just to illustrate that expression templates work in this example without the CRTP base class - and I am trying to figure out exactly why this base class is necessary.


Solution

  • Your

    template<typename E1, typename E2>
    VecAdd<E1, E2> operator + (E1 vec_expr1, E2 vec_expr2)
    

    will match any user-defined types, not merely expression types. When it is instantiated with non-vector types, it will then likely fail. This interacts very badly with other C++ types, quite possibly including standard library types, which provide their own operator + and may rely on inexact matches resolving to that operator + after implicit conversions. Because the template is an exact match for every pair of argument types, it wins overload resolution against any operator + that needs a conversion.

    Restricting operator + to arguments of type VecExpression<E> avoids that problem.
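To make the point above concrete, here is a minimal sketch of the constrained version. It is not the Wikipedia code. VecExpression follows the name used in the answer, while self(), Meters and demo() are hypothetical names introduced only for this illustration. Operands are stored by value for simplicity, as in the question.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// CRTP base: marks a type as a vector expression and forwards to the derived class.
template <typename E>
struct VecExpression {
    const E& self() const { return static_cast<const E&>(*this); }
    double operator[](std::size_t i) const { return self()[i]; }
    std::size_t size() const { return self().size(); }
};

class Vec : public VecExpression<Vec> {
    std::vector<double> data;
public:
    explicit Vec(std::size_t n) : data(n) {
        for (std::size_t i = 0; i < n; ++i) data[i] = static_cast<double>(i);
    }
    // Evaluates any vector expression element by element.
    template <typename E>
    Vec(const VecExpression<E>& expr) : data(expr.size()) {
        for (std::size_t i = 0; i < data.size(); ++i) data[i] = expr[i];
    }
    double operator[](std::size_t i) const { return data[i]; }
    std::size_t size() const { return data.size(); }
};

template <typename E1, typename E2>
class VecAdd : public VecExpression<VecAdd<E1, E2>> {
    E1 lhs;  // stored by value for simplicity
    E2 rhs;
public:
    VecAdd(const E1& l, const E2& r) : lhs(l), rhs(r) {}
    double operator[](std::size_t i) const { return lhs[i] + rhs[i]; }
    std::size_t size() const { return lhs.size(); }
};

// Only participates when both operands derive from VecExpression<...>.
template <typename E1, typename E2>
VecAdd<E1, E2> operator+(const VecExpression<E1>& a, const VecExpression<E2>& b) {
    return VecAdd<E1, E2>(a.self(), b.self());
}

// Hypothetical unrelated type that relies on an implicit conversion from double.
struct Meters {
    double value;
    Meters(double v) : value(v) {}
};
inline Meters operator+(Meters a, Meters b) { return Meters(a.value + b.value); }

inline void demo() {
    Vec a(3), b(3);              // both hold 0, 1, 2
    Vec sum = a + b + a;         // built lazily, evaluated in Vec's constructor
    assert(sum.size() == 3);
    assert(sum[2] == 6.0);

    // Template deduction fails for Meters, so Meters' own operator+ is used.
    Meters m = Meters(2.0) + 1.5;
    assert(m.value == 3.5);
}
```

With the unconstrained `operator + (E1, E2)` from the question in scope, `Meters(2.0) + 1.5` would instead pick the generic template, because it matches both argument types exactly while Meters' own operator needs a conversion from double. The result would be a VecAdd<Meters, double> rather than a Meters.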