Is there a way to make this C++14 recursive template shorter in C++17?

Tags: c++, c++17, c++14

This poly_eval function computes the result of evaluating a polynomial with a given set of coefficients at a given value of x. For example, poly_eval(5, 1, -2, -1) computes x^2 - 2x - 1 at x = 5. It's all constexpr, so if you give it constants the compiler can compute the answer at compile time.
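To make the coefficient order concrete: the first coefficient belongs to the highest power, so poly_eval(5, 1, -2, -1) is 1*5^2 - 2*5 - 1 = 14. A minimal non-template sketch of the same evaluation, using Horner's scheme, might look like this. The helper name poly_eval_ref is hypothetical and not part of the question.

```cpp
#include <initializer_list>

// Hypothetical reference helper (not from the question): evaluates the
// polynomial whose coefficients are listed highest-degree first, using
// Horner's scheme acc = acc * x + c for each coefficient in turn.
constexpr long long poly_eval_ref(long long x, std::initializer_list<long long> coeffs)
{
    long long acc = 0;
    for (long long c : coeffs)
        acc = acc * x + c;
    return acc;
}

// x^2 - 2x - 1 at x = 5: ((1) * 5 - 2) * 5 - 1 == 14
static_assert(poly_eval_ref(5, {1, -2, -1}) == 14, "x^2 - 2x - 1 at x = 5");
```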

It currently uses recursive templates to build the polynomial evaluation expression at compile time and relies on C++14's relaxed constexpr rules. I was wondering if anybody could think of a good way to remove the recursive template, perhaps using C++17. The code that exercises the template uses the __uint128_t type, an extension supported by clang and gcc.

#include <type_traits>
#include <tuple>
#include <utility> // std::pair

template <typename X_t, typename Coeff_1_T>
constexpr auto poly_eval_accum(const X_t &x, const Coeff_1_T &c1)
{
    return ::std::pair<X_t, Coeff_1_T>(x, c1);
}

template <typename X_t, typename Coeff_1_T, typename... Coeff_TList>
constexpr auto poly_eval_accum(const X_t &x, const Coeff_1_T &c1, const Coeff_TList &... coeffs)
{
    const auto &tmp_result = poly_eval_accum(x, coeffs...);
    auto saved = tmp_result.second + tmp_result.first * c1;
    return ::std::pair<X_t, decltype(saved)>(tmp_result.first * x, saved);
}

template <typename X_t, typename... Coeff_TList>
constexpr auto poly_eval(const X_t &x, const Coeff_TList &... coeffs)
{
    static_assert(sizeof...(coeffs) > 0,
                  "Must have at least one coefficient.");
    return poly_eval_accum(x, coeffs...).second;
}

// This is just a test function to exercise the template.
__uint128_t multiply_lots(__uint128_t num, __uint128_t n2)
{
    const __uint128_t cf = 5;
    return poly_eval(cf, num, n2, 10);
}

// This is just a test function to exercise the template to make sure
// it computes the result at compile time.
__uint128_t eval_const()
{
    return poly_eval(5, 1, -2, 1);
}

Also, am I doing anything wrong here?
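One thing worth knowing: constexpr only guarantees compile-time evaluation when the result is needed in a constant expression. eval_const() above is an ordinary (non-constexpr) function, so whether its body is folded is up to the optimizer. A sketch of one way to be sure, using static_assert and a constexpr variable with the question's templates copied unchanged (the test values are my own):

```cpp
#include <utility>

// poly_eval_accum / poly_eval copied from the question.
template <typename X_t, typename Coeff_1_T>
constexpr auto poly_eval_accum(const X_t &x, const Coeff_1_T &c1)
{
    return ::std::pair<X_t, Coeff_1_T>(x, c1);
}

template <typename X_t, typename Coeff_1_T, typename... Coeff_TList>
constexpr auto poly_eval_accum(const X_t &x, const Coeff_1_T &c1, const Coeff_TList &... coeffs)
{
    const auto &tmp_result = poly_eval_accum(x, coeffs...);
    auto saved = tmp_result.second + tmp_result.first * c1;
    return ::std::pair<X_t, decltype(saved)>(tmp_result.first * x, saved);
}

template <typename X_t, typename... Coeff_TList>
constexpr auto poly_eval(const X_t &x, const Coeff_TList &... coeffs)
{
    static_assert(sizeof...(coeffs) > 0, "Must have at least one coefficient.");
    return poly_eval_accum(x, coeffs...).second;
}

// static_assert only compiles if the call is a constant expression, so this
// checks both the value and that it was computed at compile time.
static_assert(poly_eval(5, 1, -2, 1) == 16, "25 - 10 + 1");

// Initializing a constexpr variable forces compile-time evaluation too.
constexpr int at_five = poly_eval(5, 1, -2, -1);  // 25 - 10 - 1
```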

-------- Comments on Answers --------

There are two excellent answers below. One is clear and terse, but may not handle certain situations involving complex types (expression trees, matrices, etc.) well, though it does a fair job. It also relies on the somewhat obscure comma (,) operator.

The other is less terse, but still much clearer than my original recursive template, and it handles types just as well. It expands out to c_n + x * (c_{n-1} + x * (c_{n-2} + ...)), whereas my recursive version expands out to c_n + x * c_{n-1} + x * x * c_{n-2} + .... For most reasonable types the two should be equivalent, and the answer can easily be modified to expand out to what my recursive one expands to.

I picked the first answer because it came first and its terseness is more within the spirit of my original question. But if I were to choose a version for production, I'd choose the second.

asked Nov 23 '17 00:11 by Omnifarious


2 Answers

Using the power of the comma operator (and C++17 fold expressions, obviously), I suppose you can write poly_eval() as follows:

template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t a, Cs_t const & ... cs)
 {
   ( (a *= x, a += cs), ..., (void)0 );

   return a;
 }

throwing away poly_eval_accum() entirely.

Note that the first coefficient is now an explicit parameter, so you can also delete the static_assert(); it is passed by copy and becomes the accumulator.
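A few static_asserts make it easy to check this C++17 version and to see the return-type issue that the EDIT below addresses. This sketch copies poly_eval() from the answer unchanged; the test values are my own. Because the first coefficient is the accumulator, the result has its type, so with an int first coefficient and a double x every step is truncated back to int.

```cpp
#include <type_traits>

// poly_eval copied from the answer (C++17 fold over the comma operator).
template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t a, Cs_t const & ... cs)
 {
   ( (a *= x, a += cs), ..., (void)0 );

   return a;
 }

static_assert(poly_eval(5, 1, -2, -1) == 14, "x^2 - 2x - 1 at x = 5");

// The result type is the type of the first coefficient: 1 * 2.5 is stored
// back into an int as 2, then 2 + 1 gives 3 rather than 3.5.
static_assert(std::is_same<decltype(poly_eval(2.5, 1, 1)), int>::value, "result is int");
static_assert(poly_eval(2.5, 1, 1) == 3, "truncated to int");
```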

-- EDIT --

Added an alternative version that solves the return-type problem, as the OP suggested, by using the decltype() of an expression (std::common_type would also work); in this version the first coefficient, c1, is a const reference again.

template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t const & c1, Cs_t const & ... cs)
 {
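   // parentheses, not braces: ret { c1 } would be a narrowing error when,
   // e.g., c1 is an int and ret is deduced as double
   decltype(((x * c1) + ... + (x * cs))) ret ( c1 );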

   ( (ret *= x, ret += cs), ..., (void)0 );

   return ret;
 }
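A sketch checking that this version fixes the mixed-type case, with the answer's code copied and test values of my own. Here ret is initialized with parentheses; a braced ret { c1 } would be a narrowing conversion, which is ill-formed, when c1 is an int and ret is deduced as double.

```cpp
#include <type_traits>

template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t const & c1, Cs_t const & ... cs)
 {
   // The type that x * c1 + x * c2 + ... naturally produces.
   decltype(((x * c1) + ... + (x * cs))) ret ( c1 );

   ( (ret *= x, ret += cs), ..., (void)0 );

   return ret;
 }

// With a double x, the accumulator is now a double: ((1 * 2.5) + 1) == 3.5.
static_assert(std::is_same<decltype(poly_eval(2.5, 1, 1)), double>::value, "result is double");
static_assert(poly_eval(2.5, 1, 1) == 3.5, "no truncation");
```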

-- EDIT 2 --

Bonus answer: it's possible to avoid the recursion in C++14 as well, using the power of the comma operator (again) and initializing an unused C-style array of integers:

template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t const & a, Cs_t const & ... cs)
 {
   using unused = int[];

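   std::common_type_t<decltype(x * a), decltype(x * cs)...>  ret ( a );

   // cast to void and yield 0, so each array element is a plain int rather
   // than ret's type (which would be a narrowing error for, e.g., __uint128_t)
   (void)unused { 0, ((void)(ret *= x, ret += cs), 0)... };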

   return ret;
 }
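A sketch of the C++14 version checked with the OP's kind of wide unsigned type (unsigned long long stands in for __uint128_t here). In this copy each array element is ((void)(...), 0), so the array always receives a plain int; an element of ret's own type would be a narrowing conversion, which is ill-formed inside the braces when ret is wider than int. ret is also initialized with parentheses for the same reason. The test values are my own.

```cpp
#include <type_traits>

template <typename X_t, typename C_t, typename ... Cs_t>
constexpr auto poly_eval (X_t const & x, C_t const & a, Cs_t const & ... cs)
 {
   using unused = int[];

   std::common_type_t<decltype(x * a), decltype(x * cs)...>  ret ( a );

   // Elements of a braced list are evaluated left to right, so the
   // coefficients are applied in order; the leading 0 keeps the array
   // non-empty when there is only one coefficient.
   (void)unused { 0, ((void)(ret *= x, ret += cs), 0)... };

   return ret;
 }

// x^2 + 2x + 10 at x = 5: ((1 * 5) + 2) * 5 + 10 == 45
static_assert(poly_eval(5ULL, 1ULL, 2ULL, 10ULL) == 45ULL, "x^2 + 2x + 10 at x = 5");
```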
answered Oct 07 '22 22:10 by max66


A great answer is supplied above, but it requires a common return type and will therefore not work if you are, say, building a compile-time expression tree.

What we need is a fold expression that, at each step, both multiplies by the value at the evaluation point x and adds a coefficient, so that we eventually end up with an expression like (((c0) * x + c1) * x + c2) * x + c3. This is (I think) not possible with a fold expression directly, but we can define a special type that overloads a binary operator and performs the necessary calculations.

template<class M, class T>
struct MultiplyAdder
{
    M mul;
    T acc;
    constexpr MultiplyAdder(M m, T a) : mul(m), acc(a) { }
};

template<class M, class T, class U>
constexpr auto operator<<(const MultiplyAdder<M,T>& ma, const U& u)
{
    return MultiplyAdder(ma.mul, ma.acc * ma.mul + u);
}

template <typename X_t, typename C_t, typename... Coeff_TList>
constexpr auto poly_eval(const X_t &x, const C_t &a, const Coeff_TList &... coeffs)
{
    return (MultiplyAdder(x, a) << ... << coeffs).acc;
}

As a bonus, this solution also ticks C++17's 'automatic class template argument deduction' box ;)

Edit: Oops, argument deduction wasn't working inside MultiplyAdder<>::operator<<(), because inside the class, MultiplyAdder is the injected-class-name: it refers to the current specialization rather than to the template, so no argument deduction takes place. I've added a namespace specifier, but that unfortunately makes it dependent on its own namespace. There must be a way to refer to its actual template-name, but I can't think of any without resorting to template aliases.

Edit2: Fixed it by making operator<<() a non-member.
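To illustrate the claim at the top of this answer, here is a sketch with a hypothetical toy expression tree: Var, Mul, Add and eval are mine, not from the answer. Every multiply and add produces a new node type, so there is no single type an accumulator could have, yet MultiplyAdder and poly_eval, copied from the answer, still build the tree at compile time. The unconstrained operator* is only suitable for a small demo like this.

```cpp
// MultiplyAdder, operator<< and poly_eval copied from the answer.
template<class M, class T>
struct MultiplyAdder
{
    M mul;
    T acc;
    constexpr MultiplyAdder(M m, T a) : mul(m), acc(a) { }
};

template<class M, class T, class U>
constexpr auto operator<<(const MultiplyAdder<M,T>& ma, const U& u)
{
    return MultiplyAdder(ma.mul, ma.acc * ma.mul + u);
}

template <typename X_t, typename C_t, typename... Coeff_TList>
constexpr auto poly_eval(const X_t &x, const C_t &a, const Coeff_TList &... coeffs)
{
    return (MultiplyAdder(x, a) << ... << coeffs).acc;
}

// Hypothetical toy expression tree: each operation yields a new node type.
struct Var {};                                      // stands for x
template <class L, class R> struct Mul { L l; R r; };
template <class L, class R> struct Add { L l; R r; };

template <class L>
constexpr Mul<L, Var> operator*(const L& l, Var) { return {l, Var{}}; }

template <class L, class R, class C>
constexpr Add<Mul<L, R>, C> operator+(const Mul<L, R>& l, const C& c) { return {l, c}; }

// Evaluate a finished tree at a concrete value of x.
constexpr double eval(Var, double x) { return x; }
constexpr double eval(double c, double) { return c; }
template <class L, class R>
constexpr double eval(const Mul<L, R>& e, double x) { return eval(e.l, x) * eval(e.r, x); }
template <class L, class R>
constexpr double eval(const Add<L, R>& e, double x) { return eval(e.l, x) + eval(e.r, x); }

// x^2 - 2x - 1 as a tree of type Add<Mul<Add<Mul<int, Var>, int>, Var>, int>
constexpr auto tree = poly_eval(Var{}, 1, -2, -1);
static_assert(eval(tree, 5.0) == 14.0, "((1 * 5) - 2) * 5 - 1");
```

With plain int arguments the same poly_eval still just computes a number, since the tree operators only apply when a Var is involved.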

answered Oct 07 '22 20:10 by oisyn