Tags: c++, templates, expression-templates

How to use expression-templates for specific types?


When using expression templates, how do I create specializations? From the Wikipedia example, I can make a Vector sum template class like so:

// Lazy element-wise sum of two vector expressions. No arithmetic happens until
// an element is requested through operator[]. Both operands are held by const
// reference, so a VecSum must not outlive the operands it was built from.
template <class E1, class E2>
class VecSum : public VecExpression<VecSum<E1, E2> > {
public:
    // Both operands are expected to have the same number of elements.
    VecSum(E1 const& u, E2 const& v)
        : lhs_(u)
        , rhs_(v)
    {
        assert(u.size() == v.size());
    }

    // i-th element of u + v, computed on demand.
    double operator[](size_t i) const { return lhs_[i] + rhs_[i]; }

    // Element count, taken from the right-hand operand.
    size_t size() const { return rhs_.size(); }

private:
    E1 const& lhs_;
    E2 const& rhs_;
};

template <typename E1, typename E2>
VecSum<E1,E2> const
operator+(E1 const& u, E2 const& v) {
   return VecSum<E1, E2>(u, v);
}

According to Wikipedia, if I have a Vector class that derives from VecExpression<Vector> and has a constructor taking a VecExpression, which evaluates it element by element with operator[] inside a loop, then the loops are fused, so statements like the following use only a single loop:

// Each + returns a lightweight VecSum rather than a Vector, so a+b+c has type
// VecSum<VecSum<Vector, Vector>, Vector>. The Vector constructor that accepts a
// VecExpression then evaluates that whole tree element by element in one loop.
Vector a = ...;
Vector b = ...;
Vector c = ...;
Vector d = a+b+c;

I understand why this works, but I'm not sure how to extend it to scalars. I want to be able to add a scalar (int, float, or double) to every element of a Vector, but I don't know how to do this. My best guess is to create partial specializations of the VecSum class like:

template<typename E2> VecSum<int, E2>{ /*stuff goes here*/ }
template<typename E1> VecSum<E1, int>{ /*stuff goes here*/ }
template<typename E2> VecSum<float, E2>{ /*stuff goes here*/ }
template<typename E1> VecSum<E1, float>{ /*stuff goes here*/ }
template<typename E2> VecSum<double, E2>{ /*stuff goes here*/ }
template<typename E1> VecSum<E1, double>{ /*stuff goes here*/ }

But this seems like a lot more work than necessary. Is there another solution?


Solution

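  • Use SFINAE: check with std::is_arithmetic whether each operand is an arithmetic type, and let std::enable_if select the matching partial specialization of VecSum. The generic operator+ from the question is reused. For 5 + v it deduces E1 = int, which picks the scalar-on-the-left specialization. No scalar + scalar specialization is needed, because an expression such as 5 + 50 uses built-in addition.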

    Example:

    // Primary template: declared but never defined. Definitions come only from
    // the partial specializations, which are chosen through the defaulted
    // Enable parameter. A combination no specialization matches (e.g. two
    // arithmetic types) therefore names an incomplete type.
    template <class E1, class E2, class Enable = void>
    class VecSum;
    
    // vector-expression + vector-expression: lazy element-wise sum.
    // Chosen when neither operand is an arithmetic type. Both operands are held
    // by const reference and must outlive this object.
    template <class E1, class E2>
    class VecSum<E1, E2,
                 std::enable_if_t<!(std::is_arithmetic<E1>::value || std::is_arithmetic<E2>::value)>>
        : public VecExpression<VecSum<E1, E2> >
    {
    public:
        VecSum(E1 const& u, E2 const& v)
            : lhs_(u)
            , rhs_(v)
        {
            // Adding vectors of different lengths is a programming error.
            assert(u.size() == v.size());
        }

        // i-th element of u + v, computed on demand.
        double operator[](size_t i) const { return lhs_[i] + rhs_[i]; }

        // Element count, taken from the right-hand operand.
        size_t size() const { return rhs_.size(); }

    private:
        E1 const& lhs_;
        E2 const& rhs_;
    };
    
    // scalar + vector-expression: adds the same scalar to every element.
    // Chosen when E1 is an arithmetic type and E2 is not.
    //
    // The scalar is copied rather than referenced. It is usually a temporary
    // (the literal in `5 + v`, or the result of `5 + 50`), and a reference to
    // it would dangle once the full-expression that created this VecSum ends.
    // Copying an arithmetic value is also at least as cheap as a reference.
    template <typename E1, typename E2>
    class VecSum < E1, E2,
          std::enable_if_t< std::is_arithmetic<E1>::value && !std::is_arithmetic<E2>::value>
          > : public VecExpression<VecSum<E1, E2> >
    {
        E1 const  _u;   // scalar, stored by value
        E2 const& _v;   // vector expression, held by reference

        public:

        VecSum(E1 const& u, E2 const& v) : _u(u), _v(v)
        {
        }

        // i-th element of (scalar + v).
        double operator[](size_t i) const { return _u + _v[i]; }
        // Element count comes from the vector operand.
        size_t size()               const { return _v.size(); }
    };
    
    
    // vector-expression + scalar: adds the same scalar to every element.
    // Chosen when E2 is an arithmetic type and E1 is not.
    //
    // The scalar is copied rather than referenced. It is usually a temporary
    // (the literal in `v + 5`), and a reference to it would dangle once the
    // full-expression that created this VecSum ends.
    // Copying an arithmetic value is also at least as cheap as a reference.
    template <typename E1, typename E2>
    class VecSum < E1, E2,
          std::enable_if_t< !std::is_arithmetic<E1>::value && std::is_arithmetic<E2>::value>
          > : public VecExpression<VecSum<E1, E2> >
    {
        E1 const& _u;   // vector expression, held by reference
        E2 const  _v;   // scalar, stored by value

        public:

        VecSum(E1 const& u, E2 const& v) : _u(u), _v(v)
        {
        }

        // i-th element of (u + scalar).
        double operator[](size_t i) const { return _u[i] + _v; }
        // Element count comes from the vector operand.
        size_t size()               const { return _u.size(); }
    };
    
    int main(){
        Vec v0 = { 1, 2, 3 ,4 };
        Vec v1 = {10, 20,30,40 };
        Vec v2 = {100,200,300,400 };
    
        {
    
            Vec sum = v0+v1+v2;
            Vec v3(4);
    
            for(int i=0;i<4;++i)
                v3[i]=sum[i];
    
    
            for(unsigned int i=0;i<v3.size();++i)
                std::cout << v3[i] << std::endl;
        }
    
        std::cout << "with lhs skalar" << std::endl;
    
        {
            Vec sum = 5 + 50 + v1;
            Vec v3(4);
    
            for(int i=0;i<4;++i)
                v3[i]=sum[i];
    
    
            for(unsigned int i=0;i<v3.size();++i)
                std::cout << v3[i] << std::endl;
    
        }
    
        std::cout << "with rhs skalar" << std::endl;
    
        {
            Vec sum = v1 + 5 + 50 ;
            Vec v3(4);
    
            for(int i=0;i<4;++i)
                v3[i]=sum[i];
    
    
            for(unsigned int i=0;i<v3.size();++i)
                std::cout << v3[i] << std::endl;
    
        }
    
    }