Symbolic differentiation using expression templates in C++

How can I implement symbolic differentiation using expression templates in C++?

coderboy asked May 10 '12 02:05


1 Answer

In general you'd want a way to represent your symbols (i.e. the expression templates that encode, e.g., 3 * x * x + 42) and a metafunction that can compute a derivative. Hopefully you're familiar enough with metaprogramming in C++ to know what that means and entails, but to give you an idea:

// This should come from the expression templates
template<typename Lhs, typename Rhs>
struct plus_node;

// Metafunction that computes a derivative
template<typename T>
struct derivative;

// derivative<foo>::type is the result of computing the derivative of foo

// Derivative of lhs + rhs
template<typename Lhs, typename Rhs>
struct derivative<plus_node<Lhs, Rhs> > {
    typedef plus_node<
        typename derivative<Lhs>::type
        , typename derivative<Rhs>::type
    > type;
};

// and so on

You'd then tie the two parts (representation and computation) together so that they are convenient to use. For example, derivative(3 * x * x + 42)(6) could mean 'compute the derivative of 3 * x * x + 42 with respect to x, evaluated at x = 6'.

However, even if you do know what it takes to write expression templates and a metaprogram in C++, I wouldn't recommend going about it this way. Template metaprogramming requires a lot of boilerplate and can be tedious. Instead, I direct you to the excellent Boost.Proto library, which is designed precisely to help write EDSLs (embedded domain-specific languages) using expression templates, and to operate on those expression templates. It is not necessarily easy to learn, but I've found that achieving the same thing without it is harder. Here's a sample program that can in fact understand and compute derivative(3 * x * x + 42)(6):

#include <iostream>

#include <boost/proto/proto.hpp>

using namespace boost::proto;

// We differentiate with respect to a single variable, the 'unknown'
struct unknown {};

// Boost.Proto calls this the expression wrapper
// elements of the EDSL will have this type
template<typename Expr>
struct expression;

// Boost.Proto calls this the domain
struct derived_domain
: domain<generator<expression>> {};

// We will use a context to evaluate expression templates
struct evaluation_context: callable_context<evaluation_context const> {
    double value;

    explicit evaluation_context(double value)
        : value(value)
    {}

    typedef double result_type;

    double operator()(tag::terminal, unknown) const
    { return value; }
};
// With this context we can write:
// evaluation_context context(42);
// eval(expr, context);
// to evaluate an expression as though the unknown had the value 42

template<typename Expr>
struct expression: extends<Expr, expression<Expr>, derived_domain> {
    typedef extends<Expr, expression<Expr>, derived_domain> base_type;

    expression(Expr const& expr = Expr())
        : base_type(expr)
    {}

    typedef double result_type;

    // We spare ourselves the need to write eval(expr, context)
    // Instead, expr(42) is available
    double operator()(double d) const
    {
        evaluation_context context(d);
        return eval(*this, context);
    }
};

// Boost.Proto calls this a transform -- we use this to operate
// on the expression templates
struct Derivative
: or_<
    when<
        terminal<unknown>
        , boost::mpl::int_<1>()
    >
    , when<
        terminal<_>
        , boost::mpl::int_<0>()
    >
    , when<
        plus<Derivative, Derivative>
        , _make_plus(Derivative(_left), Derivative(_right))
    >
    , when<
        multiplies<Derivative, Derivative>
        , _make_plus(
            _make_multiplies(Derivative(_left), _right)
            , _make_multiplies(_left, Derivative(_right))
        )
    >
    , otherwise<_>
> {};

// x is the unknown
expression<terminal<unknown>::type> const x;

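// A transform works as a function object
// (explicitly initialized: some compilers reject a default-initialized
// const object of a class without a user-provided default constructor)
Derivative const derivative = Derivative();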

int
main()
{
    double d = derivative(3 * x * x + 42)(6);
    std::cout << d << '\n';
}
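To see what the Derivative transform builds for the example, here is the expansion written with D for the transform. C++ groups 3 * x * x + 42 as ((3x)x) + 42, and the transform's rules are D(x) = 1 for the unknown, D(c) = 0 for any other terminal, D(u + v) = D(u) + D(v), and D(uv) = D(u)v + uD(v). The transform does no simplification, so the zero and unit factors survive in the result tree.

```latex
\begin{aligned}
D\big(((3x)x) + 42\big) &= D\big((3x)x\big) + D(42) \\
  &= \big(D(3x)\,x + (3x)\,D(x)\big) + 0 \\
  &= \big((D(3)\,x + 3\,D(x))\,x + (3x)\cdot 1\big) + 0 \\
  &= \big((0\cdot x + 3\cdot 1)\,x + (3x)\cdot 1\big) + 0 \\
D\big(((3x)x) + 42\big)\Big|_{x=6} &= (0\cdot 6 + 3)\cdot 6 + 18\cdot 1 + 0 = 36
\end{aligned}
```

This agrees with the simplified derivative 6x evaluated at x = 6.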
Luc Danton answered Nov 20 '22 02:11