c++ c++11 crtp perfect-forwarding expression-templates

Nesting of subexpressions in expression templates


We are writing an expression template library to handle operations on values with a sparse gradient vector (first-order automatic differentiation). I am trying to figure out how to nest subexpressions by reference or by value, depending on whether the subexpressions are temporaries.

We have a class Scalar which contains a value and a sparse gradient vector. We use expression templates (as Eigen does) to avoid constructing and allocating too many temporary Scalar objects. Thus the class Scalar inherits from ScalarBase<Scalar> (CRTP).

A binary operation (e.g. +, *) between objects of type ScalarBase< Left > and ScalarBase< Right > returns a ScalarBinaryOp<Left, Right, BinaryOp> object, which inherits from ScalarBase< ScalarBinaryOp<Left, Right, BinaryOp> >:

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
        static_cast< const Right& >( right ), BinaryAdditionOp{} );
}

ScalarBinaryOp must hold each operand, of type Left or Right, either by value or by reference. The holder type is defined through template specialization of RefTypeSelector< Expression >::Type.

Currently this is always a const reference. It works for our current test cases, but holding a reference to a temporary subexpression does not seem correct or safe.
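A minimal, self-contained sketch of the current design makes the lifetime issue concrete. It assumes, purely for illustration, that Scalar carries only a double (the sparse gradient is left out), that every expression exposes a hypothetical value() member, and that f is a hypothetical Scalar-to-Scalar function. RefTypeSelector always yields a const reference, as described above.

```cpp
#include <cassert>

// Hypothetical minimal version of the current design: Scalar carries only a
// double (no sparse gradient) and every expression exposes value().
template< typename Derived >
struct ScalarBase
{
    const Derived& derived() const { return static_cast< const Derived& >( *this ); }
};

struct Scalar : ScalarBase< Scalar >
{
    explicit Scalar( double v ) : v_( v ) {}
    // Evaluates an arbitrary expression tree into a concrete Scalar.
    template< typename E >
    Scalar( const ScalarBase< E >& e ) : v_( e.derived().value() ) {}
    double value() const { return v_; }
    double v_;
};

// Current behaviour: every operand is held by const reference.
template< typename Expression >
struct RefTypeSelector
{
    using Type = const Expression&;
};

struct BinaryAdditionOp
{
    double operator()( double a, double b ) const { return a + b; }
};

template< typename Left, typename Right, typename Op >
struct ScalarBinaryOp : ScalarBase< ScalarBinaryOp< Left, Right, Op > >
{
    ScalarBinaryOp( const Left& l, const Right& r, const Op& op )
        : left_( l ), right_( r ), op_( op ) {}
    // Lazy evaluation: nothing is computed until value() is called.
    double value() const { return op_( left_.value(), right_.value() ); }

    typename RefTypeSelector< Left >::Type left_;
    typename RefTypeSelector< Right >::Type right_;
    Op op_;
};

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( static_cast< const Left& >( left ),
        static_cast< const Right& >( right ), BinaryAdditionOp{} );
}

inline Scalar f( const Scalar& y ) { return Scalar( 2.0 * y.value() ); }

// Safe: the whole tree, including the temporaries f(y) and x + f(y), is
// evaluated before the end of the full-expression that created it.
inline double evaluate_now( const Scalar& x, const Scalar& y )
{
    Scalar r = x + f( y ) + x;
    return r.value();
}
```

Because every operand is held by const reference, an expression is only safe while it is consumed inside the full-expression that built it. Writing `auto e = x + f( y );` and calling `e.value()` in a later statement would read `f( y )` after it has been destroyed, which is undefined behaviour.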

Obviously, we also do not want to copy a Scalar object, since it contains the sparse gradient vector. If x and y are Scalar, the expression x+y should hold const references to x and y. However, if f is a function from Scalar to Scalar, x+f(y) should hold a const reference to x and the value of f(y).

Hence I would like to pass along whether each subexpression is a temporary. I can add this information to the expression's template parameters:

template< typename Left, typename Right, typename BinaryOp, bool LeftIsTemporary, bool RightIsTemporary >
class ScalarBinaryOp;

and to the RefTypeSelector:

RefTypeSelector< Expression, ExpressionIsTemporary >::Type
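The two-parameter selector proposed here can be sketched in a few lines. The definition is an assumption about the intended rule (temporaries held by value, everything else by const reference), and Scalar is only a stand-in type.

```cpp
#include <type_traits>

// Hypothetical definition of the two-parameter selector:
// temporaries are held by value, named operands by const reference.
template< typename Expression, bool ExpressionIsTemporary >
struct RefTypeSelector
{
    using Type = typename std::conditional< ExpressionIsTemporary,
                                            Expression, const Expression& >::type;
};

struct Scalar {};   // hypothetical stand-in for the real Scalar

static_assert( std::is_same< RefTypeSelector< Scalar, false >::Type, const Scalar& >::value,
               "named operand: held by const reference" );
static_assert( std::is_same< RefTypeSelector< Scalar, true >::Type, Scalar >::value,
               "temporary operand: held by value" );
```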

But then I would need to define four overloads for every binary operator:

template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, false > operator+(
    const ScalarBase< Left >& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, false, true > operator+(
    const ScalarBase< Left >& left, ScalarBase< Right >&& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, false > operator+(
    ScalarBase< Left >&& left, const ScalarBase< Right >& right );
template< typename Left, typename Right >
ScalarBinaryOp< Left, Right, BinaryAdditionOp, true, true > operator+(
    ScalarBase< Left >&& left, ScalarBase< Right >&& right );

I would prefer to achieve this with perfect forwarding, but I do not know how to do so here. First, I cannot simply use universal (forwarding) references, because they match almost anything. I guess it might be possible to combine universal references with SFINAE to allow only certain parameter types, but I am not sure this is the way to go. I would also like to know whether I could encode whether Left and Right were originally lvalues or rvalues in the types Left and Right that parameterize ScalarBinaryOp, instead of using the two additional boolean parameters, and how to retrieve that information.
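The value-category information is already carried by the deduced type: for a forwarding reference parameter T&&, T is deduced as an lvalue reference type when the argument is an lvalue and as a plain (non-reference) type when it is an rvalue. The small C++11 check below illustrates this. Scalar and f are hypothetical stand-ins, and deduced() is a helper that exists only for the demonstration.

```cpp
#include <type_traits>
#include <utility>

// Hypothetical stand-ins: only the value categories of these expressions matter.
struct Scalar {};
Scalar f( const Scalar& s );   // returns a temporary; declared only, used in decltype

// Never defined: it is only named inside decltype to reveal what a forwarding
// reference parameter T&& deduces T to be for a given argument.
template< typename T >
T deduced( T&& );

inline void check_deduction( Scalar& x, const Scalar& cx )
{
    // lvalue arguments: T is deduced as an lvalue reference type
    static_assert( std::is_same< decltype( deduced( x ) ), Scalar& >::value, "lvalue" );
    static_assert( std::is_same< decltype( deduced( cx ) ), const Scalar& >::value, "const lvalue" );
    // rvalue arguments: T is deduced as the plain object type
    static_assert( std::is_same< decltype( deduced( f( x ) ) ), Scalar >::value, "prvalue" );
    static_assert( std::is_same< decltype( deduced( std::move( x ) ) ), Scalar >::value, "xvalue" );
}
```

So an operator+ taking `Left&& left, Right&& right` can pass its deduced Left and Right straight into the expression type, and std::is_lvalue_reference< Left > recovers the category later. Restricting the operator to ScalarBase-derived arguments still needs SFINAE or a similar constraint, since an unconstrained forwarding reference matches any argument.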

I have to support GCC 4.8.5, which is mostly C++11 compliant.

Update 2019/08/15: implementation

template < typename Expr >
class RefTypeSelector
{
   private:
   using Expr1 = typename std::decay<Expr>::type;
   public:
   using Type = typename std::conditional<std::is_lvalue_reference<Expr>::value, const Expr1&,Expr1>::type;
};
template< typename Left, typename Right, typename Op >
class ScalarBinaryOp : public ScalarBase< ScalarBinaryOp< Left, Right, Op > >
{

public:

    template <typename L, typename R>
    ScalarBinaryOp( L&& left, R&& right, const Op& op )
        : left_( std::forward<L>(left) )
        , right_( std::forward<R>(right) )
        , ...
    {
    ...
    }

    ...

private:
    /** LHS expression */
    typename RefTypeSelector< Left >::Type left_;

    /** RHS expression */
    typename RefTypeSelector< Right >::Type right_;

...
};

template< typename Left, typename Right,
typename Left1 = typename std::decay<Left>::type,
typename Right1 = typename std::decay<Right>::type,
typename std::enable_if<std::is_base_of<ScalarBase<Left1>, Left1>::value,int>::type = 0,
typename std::enable_if<std::is_base_of<ScalarBase<Right1>, Right1>::value,int>::type = 0 >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+(
    Left&& left, Right&& right )
{

    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( std::forward<Left>( left ),
        std::forward<Right>( right ), BinaryAdditionOp{} );
}
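To check that this scheme behaves as intended, here is a self-contained sketch that uses only C++11 features and is built around the same RefTypeSelector, forwarding ScalarBinaryOp constructor and constrained operator+. The parts elided in the update are replaced by hypothetical minimal pieces: Scalar holds only a double, evaluation goes through a value() member, f is a stand-in function, and the LeftHolder/RightHolder aliases are exposed only so the storage choice can be checked with static_assert.

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical simplification: Scalar carries only a double (no sparse
// gradient) and every expression node exposes value().
template< typename Derived >
struct ScalarBase
{
    const Derived& derived() const { return static_cast< const Derived& >( *this ); }
};

struct Scalar : ScalarBase< Scalar >
{
    explicit Scalar( double v ) : v_( v ) {}
    // Evaluates an arbitrary expression tree into a concrete Scalar.
    template< typename E >
    Scalar( const ScalarBase< E >& e ) : v_( e.derived().value() ) {}
    double value() const { return v_; }
    double v_;
};

// Same rule as the update: lvalue operand -> const reference, otherwise value.
template< typename Expr >
struct RefTypeSelector
{
    using Expr1 = typename std::decay< Expr >::type;
    using Type = typename std::conditional< std::is_lvalue_reference< Expr >::value,
                                            const Expr1&, Expr1 >::type;
};

struct BinaryAdditionOp
{
    double operator()( double a, double b ) const { return a + b; }
};

template< typename Left, typename Right, typename Op >
class ScalarBinaryOp : public ScalarBase< ScalarBinaryOp< Left, Right, Op > >
{
public:
    using LeftHolder = typename RefTypeSelector< Left >::Type;
    using RightHolder = typename RefTypeSelector< Right >::Type;

    template< typename L, typename R >
    ScalarBinaryOp( L&& left, R&& right, const Op& op )
        : left_( std::forward< L >( left ) ), right_( std::forward< R >( right ) ), op_( op ) {}

    double value() const { return op_( left_.value(), right_.value() ); }

private:
    LeftHolder left_;    // const reference if the operand was an lvalue
    RightHolder right_;  // moved-in value if the operand was a temporary
    Op op_;
};

template< typename Left, typename Right,
    typename Left1 = typename std::decay< Left >::type,
    typename Right1 = typename std::decay< Right >::type,
    typename std::enable_if< std::is_base_of< ScalarBase< Left1 >, Left1 >::value, int >::type = 0,
    typename std::enable_if< std::is_base_of< ScalarBase< Right1 >, Right1 >::value, int >::type = 0 >
ScalarBinaryOp< Left, Right, BinaryAdditionOp > operator+( Left&& left, Right&& right )
{
    return ScalarBinaryOp< Left, Right, BinaryAdditionOp >( std::forward< Left >( left ),
        std::forward< Right >( right ), BinaryAdditionOp{} );
}

inline Scalar f( const Scalar& y ) { return Scalar( 2.0 * y.value() ); }

// x + f(y): x is an lvalue (Left = Scalar&), f(y) is a temporary (Right = Scalar).
using XPlusFy = decltype( std::declval< Scalar& >() + f( std::declval< const Scalar& >() ) );
static_assert( std::is_same< XPlusFy::LeftHolder, const Scalar& >::value, "x held by reference" );
static_assert( std::is_same< XPlusFy::RightHolder, Scalar >::value, "f(y) held by value" );

// In x + y + z the temporary subexpression x + y is stored by value.
using Nested = decltype( std::declval< Scalar& >() + std::declval< Scalar& >() + std::declval< Scalar& >() );
static_assert( !std::is_reference< Nested::LeftHolder >::value, "x + y held by value" );

inline double deferred( const Scalar& x, const Scalar& y )
{
    auto e = x + f( y );   // the temporary f(y) is destroyed at the end of this line...
    return e.value();      // ...but e owns a moved copy of it, so this is safe
}
```

Because Left is deduced as Scalar& (or const Scalar&) for a named operand and as a plain type for a temporary, this single forwarding overload covers the four combinations that would otherwise need separate overloads.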


Solution

  • You can encode lvalue/rvalue information into Left and Right types. For example:

    template<class Left, class Right>
    ScalarBinaryOp<Left&&, Right&&> operator+(
        ScalarBase<Left>&& left, ScalarBase<Right>&& right)
    {
        return ...;
    }
    

    with ScalarBinaryOp being something like this:

    template<class L, class R>
    struct ScalarBinaryOp
    {
        using Left = std::remove_reference_t<L>;
        using Right = std::remove_reference_t<R>;
    
        using My_left = std::conditional_t<
            std::is_rvalue_reference_v<L>, Left, const Left&>;
        using My_right = std::conditional_t<
            std::is_rvalue_reference_v<R>, Right, const Right&>;
    
        ...
    
        My_left left_;
        My_right right_;
    };
    

    Alternatively, you can be explicit and store everything by value, except for Scalars. To be able to store a Scalar by value, use a wrapper class:

    x + Value_wrapper(f(y))
    

    The wrapper is simple:

    struct Value_wrapper : Base<Value_wrapper> {
        Value_wrapper(Scalar&& scalar) : scalar_(std::move(scalar)) {}
    
        operator Scalar() const {
            return std::move(scalar_);
        }
    
        Scalar&& scalar_;
    };
    

    RefTypeSelector has a specialization for Value_wrapper:

    template<> struct RefTypeSelector<Value_wrapper> {
        using Type = Scalar;
    };
    

    The binary operator definition remains the same:

    template<class Left, class Right>
    ScalarBinaryOp<Left, Right> operator+(const Base<Left>& left, const Base<Right>& right) {
        return {static_cast<const Left&>(left), static_cast<const Right&>(right)};
    }
    

    Complete example: https://godbolt.org/z/sJ3NfG

    (I used some C++17 features above only to simplify the notation.)
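Since the question has to support GCC 4.8.5, the wrapper approach can also be written without C++17 features. The sketch below is not the linked example; it is a minimal C++11 illustration under stated assumptions: Base is an empty CRTP tag, Scalar holds only a double, the node only implements addition, f is a stand-in function, and the primary RefTypeSelector stores expression nodes by value while specializations store plain Scalars by const reference and Value_wrapper as an owned Scalar, following "store everything by value, except for Scalars".

```cpp
#include <cassert>
#include <type_traits>
#include <utility>

// Hypothetical minimal types: Base is an empty CRTP tag, Scalar holds a double.
template< class Derived >
struct Base {};

struct Scalar : Base< Scalar >
{
    explicit Scalar( double v ) : v_( v ) {}
    double value() const { return v_; }
    double v_;
};

// Refers to a temporary Scalar; converting to Scalar moves it out so that
// the expression node can own it.
struct Value_wrapper : Base< Value_wrapper >
{
    Value_wrapper( Scalar&& scalar ) : scalar_( std::move( scalar ) ) {}
    operator Scalar() const { return std::move( scalar_ ); }
    Scalar&& scalar_;
};

// Assumed storage rules: nodes by value, Scalars by const reference,
// wrapped temporaries as an owned Scalar.
template< class E > struct RefTypeSelector { using Type = E; };
template<> struct RefTypeSelector< Scalar > { using Type = const Scalar&; };
template<> struct RefTypeSelector< Value_wrapper > { using Type = Scalar; };

static_assert( std::is_same< RefTypeSelector< Value_wrapper >::Type, Scalar >::value,
               "wrapped temporaries are owned by the node" );

template< class Left, class Right >
struct ScalarBinaryOp : Base< ScalarBinaryOp< Left, Right > >
{
    // For Right = Value_wrapper, right_( r ) calls the conversion operator.
    ScalarBinaryOp( const Left& l, const Right& r ) : left_( l ), right_( r ) {}
    double value() const { return left_.value() + right_.value(); }   // addition only

    typename RefTypeSelector< Left >::Type left_;
    typename RefTypeSelector< Right >::Type right_;
};

template< class Left, class Right >
ScalarBinaryOp< Left, Right > operator+( const Base< Left >& left, const Base< Right >& right )
{
    return { static_cast< const Left& >( left ), static_cast< const Right& >( right ) };
}

inline Scalar f( const Scalar& y ) { return Scalar( 2.0 * y.value() ); }

inline double deferred( const Scalar& x, const Scalar& y )
{
    auto e = x + Value_wrapper( f( y ) );   // f(y) is moved into e.right_
    return e.value();                       // safe after the temporaries are gone
}
```

Value_wrapper itself only refers to f(y); it is the conversion operator, invoked while right_ is initialized, that moves the Scalar into the node, which is why e can still be evaluated after the statement that created it.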