
Transforming an expression template tree

Given an expression template tree, I want to create a new optimized tree before processing it. Consider the following example of a multiplication operation:

a * b * c * d,

which produces, due to the left-to-right associativity of operator*, the expression tree:

(((a * b) * c) * d).

I would like to produce a transformed expression tree where multiplication occurs from right-to-left:

(a * (b * (c * d))).
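
The two shapes can be checked at compile time. The sketch below is only an illustration: `Term`, `Mul`, `CheckShapes` and this by-value `operator*` are hypothetical stand-ins, not the types used in the rest of this question.

```cpp
#include <type_traits>

struct Term {};                        // hypothetical leaf operand

template<typename L, typename R>
struct Mul { L left; R right; };       // hypothetical product node, stores operands by value

// Hypothetical operator*: builds a node from its two operands.
template<typename L, typename R>
Mul<L, R> operator*(L l, R r) { return {l, r}; }

using LeftNested  = Mul<Mul<Mul<Term, Term>, Term>, Term>;   // (((a * b) * c) * d)
using RightNested = Mul<Term, Mul<Term, Mul<Term, Term>>>;   // (a * (b * (c * d)))

inline void CheckShapes()
{
    Term a, b, c, d;
    auto left  = a * b * c * d;        // no parentheses: grouped left to right
    auto right = a * (b * (c * d));    // explicit right-to-left grouping
    static_assert(std::is_same<decltype(left),  LeftNested>::value,  "unparenthesized product is left-nested");
    static_assert(std::is_same<decltype(right), RightNested>::value, "parenthesized product is right-nested");
    (void)left; (void)right;           // silence unused-variable warnings
}
```

The goal of the question is to get from the first type to the second without writing the parentheses by hand.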

Consider the binary expression type:

template<typename Left, typename Right>
struct BinaryTimesExpr
{
    BinaryTimesExpr() = default;
    BinaryTimesExpr(const BinaryTimesExpr&) = default;
    BinaryTimesExpr(BinaryTimesExpr&&) = default;
    BinaryTimesExpr(Left&& l, Right&& r) : left(std::forward<Left>(l)), right(std::forward<Right>(r)) {}

    BinaryTimesExpr& operator=(const BinaryTimesExpr&) = default;
    BinaryTimesExpr& operator=(BinaryTimesExpr&&) = default;

    Left left;
    Right right;
};

Define the multiplication operator operator*:

template<typename Left, typename Right>
BinaryTimesExpr<ConstifyRef<Left>, ConstifyRef<Right>> operator*(Left&& l, Right&& r)
{
    return {std::forward<Left>(l), std::forward<Right>(r)};
}

where ConstifyRef is defined by:

template<typename T> struct HelperConstifyRef     { using type = T;  };
template<typename T> struct HelperConstifyRef<T&> { using type = const T&; };
template<typename T>
using ConstifyRef = typename HelperConstifyRef<T>::type;

and is used to ensure that sub-expressions are held by const lvalue reference when constructed from lvalues, and by value (copy- or move-constructed) when constructed from rvalues.
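
A small compile-time check of that mapping, as a sketch: `stored_as` and `Matrix` are hypothetical names introduced here only to probe which type a forwarding-reference parameter would end up stored as.

```cpp
#include <type_traits>
#include <utility>

template<typename T> struct HelperConstifyRef     { using type = T;  };
template<typename T> struct HelperConstifyRef<T&> { using type = const T&; };
template<typename T>
using ConstifyRef = typename HelperConstifyRef<T>::type;

struct Matrix {};  // hypothetical operand type

// Hypothetical probe: its return type is the type a node would store for Arg.
// Declaration only; it is used exclusively inside decltype.
template<typename Arg>
ConstifyRef<Arg> stored_as(Arg&&);

// lvalue argument: Arg deduces to Matrix&, stored as const Matrix&
static_assert(std::is_same<decltype(stored_as(std::declval<Matrix&>())), const Matrix&>::value,
              "lvalues are held by const reference");
// const lvalue argument: Arg deduces to const Matrix&, stored as const Matrix&
static_assert(std::is_same<decltype(stored_as(std::declval<const Matrix&>())), const Matrix&>::value,
              "const lvalues are held by const reference");
// rvalue argument: Arg deduces to Matrix, stored by value
static_assert(std::is_same<decltype(stored_as(std::declval<Matrix>())), Matrix>::value,
              "rvalues are held by value");
```

Holding rvalues by value matters because a temporary sub-expression such as `a * b` would otherwise dangle once the full expression ends.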

Define the transformation function that creates a new expression template tree with the right-to-left nesting described above:

template<typename Expr>
auto Transform(const Expr& expr) -> Expr
{
    return expr;
}

template<typename Left, typename Right>
auto Transform(const BinaryTimesExpr<Left, Right>& expr) -> type(???)
{
    return {Transform(expr.left), Transform(expr.right)};
}

template<typename LeftLeft, typename LeftRight, typename Right>
auto Transform(const BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right>& expr) -> type(???)
{
    return Transform({Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}}); // this syntax is invalid... how can I write this?
}

My questions are:

1) How do I determine the return types of the Transform functions? I've tried using type traits like:

template<typename Expr>
struct HelperTransformedExpr
{
    using type = Expr;
};

template<typename Left, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<Left, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<Left>::type, typename HelperTransformedExpr<Right>::type>;
};

template<typename LeftLeft, typename LeftRight, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<LeftLeft>::type,
        BinaryTimesExpr<typename HelperTransformedExpr<LeftRight>::type, typename HelperTransformedExpr<Right>::type>>;
};

template<typename Expr>
using TransformedExpr = typename HelperTransformedExpr<Expr>::type;

but I don't know how to apply this to solve question (2) below.
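
To see what this trait actually computes, here is a compile-time check. It is a sketch under assumptions: the node is reduced to a plain aggregate, and `A`, `B`, `C`, `D` and the alias `Mul` are hypothetical placeholders.

```cpp
#include <type_traits>

template<typename Left, typename Right>
struct BinaryTimesExpr { Left left; Right right; };  // simplified node, members only

template<typename Expr>
struct HelperTransformedExpr { using type = Expr; };

template<typename Left, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<Left, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<Left>::type,
                                 typename HelperTransformedExpr<Right>::type>;
};

template<typename LeftLeft, typename LeftRight, typename Right>
struct HelperTransformedExpr<BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right>>
{
    using type = BinaryTimesExpr<typename HelperTransformedExpr<LeftLeft>::type,
        BinaryTimesExpr<typename HelperTransformedExpr<LeftRight>::type,
                        typename HelperTransformedExpr<Right>::type>>;
};

template<typename Expr>
using TransformedExpr = typename HelperTransformedExpr<Expr>::type;

struct A {}; struct B {}; struct C {}; struct D {};        // hypothetical leaves
template<typename L, typename R> using Mul = BinaryTimesExpr<L, R>;

// One level of left nesting is rotated as intended: (A*B)*C becomes A*(B*C).
static_assert(std::is_same<TransformedExpr<Mul<Mul<A, B>, C>>,
                           Mul<A, Mul<B, C>>>::value, "single rotation");

// Two levels: ((A*B)*C)*D becomes (A*B)*(C*D), which is not yet fully right-nested.
static_assert(std::is_same<TransformedExpr<Mul<Mul<Mul<A, B>, C>, D>>,
                           Mul<Mul<A, B>, Mul<C, D>>>::value, "one rotation per application");
```

With these simplified types the trait performs one rotation per application, so a fully right-nested result would need the rotated node to be transformed again.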

2) How do I write the recursion line:

return Transform({Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}});

3) Is there a cleaner solution for this problem?


Edit: DyP presents a partial solution to the above problem. Below is my full solution based on his answer:

template<typename Expr>
auto Transform(const Expr& expr) -> Expr
{
    return expr;
}

template<typename Left, typename Right>
auto Transform(BinaryTimesExpr<Left, Right> const& expr)
-> decltype(BinaryTimesExpr<decltype(Transform(expr.left)), decltype(Transform(expr.right))>{Transform(expr.left), Transform(expr.right)})
{
    return BinaryTimesExpr<decltype(Transform(expr.left)), decltype(Transform(expr.right))>{Transform(expr.left), Transform(expr.right)};
}

template<typename LeftLeft, typename LeftRight, typename Right>
auto Transform(BinaryTimesExpr<BinaryTimesExpr<LeftLeft, LeftRight>, Right> const& expr)
-> decltype(Transform(BinaryTimesExpr<decltype(Transform(expr.left.left)), BinaryTimesExpr<decltype(Transform(expr.left.right)), decltype(Transform(expr.right))>>{Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}}))
{
    return Transform(BinaryTimesExpr<decltype(Transform(expr.left.left)), BinaryTimesExpr<decltype(Transform(expr.left.right)), decltype(Transform(expr.right))>>{Transform(expr.left.left), {Transform(expr.left.right), Transform(expr.right)}});
}

int main()
{
    BinaryTimesExpr<int, int> beg{1,2};
    auto res = beg*3*4*5*beg;
    std::cout << res << std::endl;
    std::cout << Transform(res) << std::endl;
}

Output:

(((((1*2)*3)*4)*5)*(1*2))
(1*(2*(3*(4*(5*(1*2))))))

Note that it was necessary to apply Transform to every sub-expression, in addition to the outermost Transform call (see the last Transform overload).
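
The `main` above streams expressions to `std::cout`, which needs an `operator<<` that is not shown in this post. A minimal sketch of one that would produce the notation in the output, assuming a simplified aggregate node (the constructors are omitted) and hypothetical helpers `ToString` and `CheckPrinter`:

```cpp
#include <cassert>
#include <ostream>
#include <sstream>
#include <string>

template<typename Left, typename Right>
struct BinaryTimesExpr { Left left; Right right; };  // simplified node, members only

// Prints "(left*right)"; nested nodes recurse through this same overload,
// and leaves use whatever operator<< their own type provides.
template<typename Left, typename Right>
std::ostream& operator<<(std::ostream& os, const BinaryTimesExpr<Left, Right>& expr)
{
    return os << '(' << expr.left << '*' << expr.right << ')';
}

// Hypothetical helper: renders any streamable expression to a string.
template<typename Expr>
std::string ToString(const Expr& expr)
{
    std::ostringstream out;
    out << expr;
    return out.str();
}

// Hypothetical self-check on a small left-nested tree.
inline void CheckPrinter()
{
    BinaryTimesExpr<BinaryTimesExpr<int, int>, int> expr{{1, 2}, 3};
    assert(ToString(expr) == "((1*2)*3)");
}
```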

The full source code can be found here.

asked Aug 13 '13 at 15:08 by Allan



1 Answer

Although the OP wanted a solution that didn't use Boost.Proto, others might be interested to see how this could be (pretty easily) done using it:

#include <iostream>
#include <boost/type_traits/is_same.hpp>
#include <boost/proto/proto.hpp>

namespace proto = boost::proto;
using proto::_;

struct empty {};

struct transform
  : proto::reverse_fold_tree<
        _
      , empty()
      , proto::if_<
            boost::is_same<proto::_state, empty>()
          , _
          , proto::_make_multiplies(_, proto::_state)
        >
    >
{};

int main()
{
    proto::literal<int> a(1), b(2), c(3), d(4);
    proto::display_expr( a * b * c * d );
    proto::display_expr( transform()(a * b * c * d) );
}

The above code displays:

multiplies(
    multiplies(
        multiplies(
            terminal(1)
          , terminal(2)
        )
      , terminal(3)
    )
  , terminal(4)
)
multiplies(
    terminal(1)
  , multiplies(
        terminal(2)
      , multiplies(
            terminal(3)
          , terminal(4)
        )
    )
)
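
For readers who would rather not depend on Boost, the same reverse fold can be sketched by hand. This only illustrates the idea behind `reverse_fold_tree` and is not how Proto implements it: `Leaf`, `Times`, `Empty`, `ReverseFold`, `Show` and `CheckReverseFold` are hypothetical names, nodes store their operands by value, and C++14 return type deduction is assumed.

```cpp
#include <cassert>
#include <string>

struct Leaf { int value; };                 // hypothetical terminal
template<typename L, typename R>
struct Times { L left; R right; };          // hypothetical by-value product node
struct Empty {};                            // initial fold state, like `empty` above

// Leaf with no state yet: the leaf itself becomes the state.
inline Leaf ReverseFold(const Leaf& leaf, const Empty&) { return leaf; }

// Leaf with an accumulated state: build leaf * state.
template<typename State>
Times<Leaf, State> ReverseFold(const Leaf& leaf, const State& state) { return {leaf, state}; }

// Inner node: fold the right subtree first, then fold the left subtree
// using that result as the state.
template<typename L, typename R, typename State>
auto ReverseFold(const Times<L, R>& node, const State& state)
{
    return ReverseFold(node.left, ReverseFold(node.right, state));
}

template<typename Expr>
auto Transform(const Expr& expr) { return ReverseFold(expr, Empty{}); }

// Printer used only to inspect the resulting shape.
inline std::string Show(const Leaf& leaf) { return std::to_string(leaf.value); }
template<typename L, typename R>
std::string Show(const Times<L, R>& node)
{
    return "(" + Show(node.left) + "*" + Show(node.right) + ")";
}

inline void CheckReverseFold()
{
    Times<Times<Times<Leaf, Leaf>, Leaf>, Leaf> tree{{{{1}, {2}}, {3}}, {4}};
    assert(Show(tree) == "(((1*2)*3)*4)");
    assert(Show(Transform(tree)) == "(1*(2*(3*4)))");
}
```

Because the fold visits terminals from right to left and threads the partial product through as state, it does not need a separate rotation overload for left-nested nodes.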
answered Oct 21 '22 at 01:10 by Eric Niebler
