 

C++ metaprogramming - compile time search tree

UPDATE: sorry for the confusing terms: I do not need a binary tree, but a segment tree or interval tree.

Imagine I have to statically initialize a search tree each time my program is executed.

Tree t;
t.add(10, "Apple");
t.add(20, "Pear");
t.add(50, "Orange");
...
t.add(300, "Cucumber");

...
// then I use it.
int key = 15;
std::string s = t.lookup(key); // Returns "Apple" (as the key is between 10 and 20)

The keys and values in the tree are "static" (hard-coded), but they have to be maintained from time to time. Is there a metaprogramming trick to organise the key/value pairs into a binary search tree (or a skip list) at compile time?

For example, could the whole search tree be implemented directly in code (.text), with nothing held in .data? I can also "predict" the number of keys and provide them in order.

ibre5041 asked Mar 17 '23 23:03


2 Answers

I suspect you are making a mountain out of a molehill here, and that it's because:

  • You believe that to statically initialize something in C++ you have to do it at compile time.

  • Either you are not acquainted with the concepts of upper and lower bounds or else you don't know that the {upper|lower} bound of v in a [partially] ordered sequence S can be determined by binary search of S, and that you can count on the Standard Library to do it at least that efficiently.
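To make the second point concrete, here is a small sketch of how the Standard Library's binary searches give both bounds on a sorted sequence. The helper names `largest_key_not_above` and `smallest_key_not_below` are hypothetical, chosen only for illustration. `std::upper_bound` returns the first element greater than n, so the element just before it (if any) is the largest key not larger than n; `std::lower_bound` returns the first element not less than n.

```cpp
#include <algorithm>
#include <vector>

// Largest key in the sorted vector `keys` that is not larger than n,
// or nullptr if every key is larger than n.
// std::upper_bound finds the first key > n by binary search, so the
// answer is the element immediately before that position.
const int *largest_key_not_above(const std::vector<int> &keys, int n)
{
    auto it = std::upper_bound(keys.begin(), keys.end(), n);
    return it == keys.begin() ? nullptr : &*(it - 1);
}

// Smallest key that is not smaller than n (the lower bound),
// or nullptr if every key is smaller than n.
const int *smallest_key_not_below(const std::vector<int> &keys, int n)
{
    auto it = std::lower_bound(keys.begin(), keys.end(), n);
    return it == keys.end() ? nullptr : &*it;
}
```

With keys `{2, 5, 9, 14, 20}`, `largest_key_not_above(keys, 12)` points at 9 and `smallest_key_not_below(keys, 10)` points at 14; both run in logarithmic time.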

I think you want to have a statically initialized data structure mapping integer keys to string literals such that, at runtime, you can query it with an integer n and very efficiently retrieve the string literal s (if any), whose key is the largest that is not larger than n - with the additional proviso, presumably, that n is not larger than all keys.

If that is right, then the statically initialized data structure you need is simply a statically initialized map M of integers to string literals. Template meta-programming is not in the frame.

Because of the (presumed) proviso that a query shall fail for n larger than all keys, you will need to include a sentinel value in M with a key 1 larger than the largest key you want to find.

Then, for runtime integer n, you query M for the upper bound of n. The upper bound of n in M is the smallest key larger than n, if any. If the returned iterator, it, is M.end(), then you have no string for n. Otherwise, if it == M.begin(), every key is greater than n, so again you have no string for n. Otherwise, there must exist a <key,value> located by --it, and that key must be the largest key that is not larger than n. So your string for n is that value.

#include <map>

static const std::map<int,char const *> tab = 
{
    { 2,"apple" },
    { 5,"pear" },
    { 9,"orange" },
    { 14,"banana" },
    { 20,"plum" },
    { 20 + 1,nullptr }
};

const char * lookup(int n)
{
    auto it = tab.upper_bound(n);
    return it == tab.begin() || it == tab.end() ? nullptr : (--it)->second;
}

Prepend that to this example:

#include <iostream>

using namespace std;

int main(void)
{
    for (int i = 0; i <= 21; ++i) {
        cout << i;
        char const *out = lookup(i);
        cout << " -> " << (!out ? "Not found" : out) << endl;
    }
    return 0;
}

and the output will be:

0 -> Not found
1 -> Not found
2 -> apple
3 -> apple
4 -> apple
5 -> pear
6 -> pear
7 -> pear
8 -> pear
9 -> orange
10 -> orange
11 -> orange
12 -> orange
13 -> orange
14 -> banana
15 -> banana
16 -> banana
17 -> banana
18 -> banana
19 -> banana
20 -> plum
21 -> Not found

Now tab in this program is a static data structure, but it is not initialized at compile time. It is initialized at program startup, during the dynamic initialization of static objects, before main is called. Unless you need to shave nanoseconds off your program startup, I can't think why you would need the map to be initialized at compile time.

If nevertheless you do require it to be initialized at compile time, it is just a little fiddlier than this. You will need the map to be a constexpr object, meaning the compiler can construct it at compile time; for that it must be of a literal type, and that means you cannot use std::map, because it is not a literal type.

Therefore you will have to use instead:

#include <utility>

constexpr std::pair<int,char const *> tab[]
{
    { 2,"apple" },
    { 5,"pear" },
    { 9,"orange" },
    { 14,"banana" },
    { 20,"plum" },
    { 20 + 1,nullptr }
};

or similar, and implement lookup(n) in essentially the manner shown, but invoking std::upper_bound on tab. There you'll find the slightly fiddlier bits, which I'll leave to you as an exercise, if you want it.
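One possible way to do that exercise is sketched below, assuming C++14 or later (needed both for the constexpr `std::pair` constructors and for loops inside constexpr functions). The runtime `lookup` reuses the answer's names and calls `std::upper_bound` on the raw array with a comparator on `.first`. The second function, `lookup_ce` (a hypothetical name), hand-rolls the same upper-bound search, because `std::upper_bound` itself is constexpr only from C++20; that lets `static_assert` check results while compiling.

```cpp
#include <algorithm>
#include <cstddef>
#include <iterator>
#include <utility>

// The constexpr table from the answer: sorted by key, with a sentinel entry
// one larger than the largest real key.
constexpr std::pair<int, char const *> tab[]{
    {2, "apple"}, {5, "pear"}, {9, "orange"},
    {14, "banana"}, {20, "plum"}, {20 + 1, nullptr}};

constexpr std::size_t tab_size = sizeof(tab) / sizeof(tab[0]);

// Runtime lookup via std::upper_bound on the raw array. The comparator is
// invoked as comp(n, element), i.e. "is n less than this element's key?".
const char *lookup(int n)
{
    auto first = std::begin(tab);
    auto last = std::end(tab);
    auto it = std::upper_bound(first, last, n,
        [](int key, const std::pair<int, char const *> &e) { return key < e.first; });
    return it == first || it == last ? nullptr : (it - 1)->second;
}

// Compile-time capable lookup: a hand-written upper-bound search,
// since std::upper_bound is not constexpr before C++20.
constexpr const char *lookup_ce(int n)
{
    std::size_t lo = 0, hi = tab_size;
    // Invariant: keys before lo are <= n, keys from hi onwards are > n.
    while (lo < hi) {
        std::size_t mid = lo + (hi - lo) / 2;
        if (n < tab[mid].first)
            hi = mid;
        else
            lo = mid + 1;
    }
    // lo is now the upper-bound index, exactly as std::upper_bound would give.
    return lo == 0 || lo == tab_size ? nullptr : tab[lo - 1].second;
}

// Checked by the compiler: 15 lies in [14, 20), and 21 is the sentinel key.
static_assert(lookup_ce(15)[0] == 'b', "15 should map to \"banana\"");
static_assert(lookup_ce(21) == nullptr, "keys past 20 are not found");
```

The runtime `lookup` is a drop-in replacement for the `std::map` version, so the `main` from the answer should print the same table. Where the array ends up (.rodata or folded into code) is still the compiler's decision.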

Mike Kinghan answered Mar 19 '23 11:03


I finally built what I wanted to achieve. It's overcomplicated, and it looks like compiler optimizers are much smarter than I thought.

#include <cstddef>
#include <iostream>
#include <string>
#include <typeinfo>

// Log "function"
template <int N>
struct LOG
{
    static const int value = LOG<N/2>::value + 1;
};
template<>
struct LOG<0>
{
    static const int value = 0;
};

// pow "function"
template <int N>
struct POW
{
    static const int value = POW<N-1>::value * 2;
};
template<>
struct POW<1>
{
    static const int value = 2;
};

template<>
struct POW<0>
{
    static const int value = 1;
};

// Pair <key, value> to be a payload in a type list
template<int k, char v>
struct Pair
{
    static const int key = k;
    static const int value = v;
};

// type list manipulator - access n-th element
template <size_t, class...> struct element;
template <class TT, class...TTs>
struct element<0, TT, TTs...>
{
    typedef TT type;
};
template <size_t K, class TT, class...TTs>
struct element<K, TT, TTs...>
{
    typedef typename element<K-1, TTs...>::type type;
};

template<class... Ts>
struct V;

// Binary split search algorithm (pure templatized)
template<class T, class... Ts>
struct V<T, Ts...> : private V<Ts...>
{
    template<size_t N = sizeof...(Ts), size_t level = LOG<sizeof...(Ts)+1>::value>
    struct impl
    {
        template<size_t IDX>
        inline static char search_impl(size_t n)
        {
            using namespace std;
            static const int depth = LOG<N>::value;
            static const int layer = depth - level;
            static const int key   = element<IDX, T, Ts...>::type::key;
            static const size_t left_idx  = IDX - ( N / POW<layer + 2>::value + 1);
            static const size_t right_idx =
                IDX + ( N / POW<layer + 2>::value + 1) > sizeof...(Ts) ?
                sizeof...(Ts) :
                IDX + ( N / POW<layer + 2>::value + 1);             

            //std::cout << setfill('*') << setw(layer) << ' '
            //    << "Level:" << level << " of:" << depth << std::endl
            //    << std::setw(layer) << ' ' 
            //    << "IDX/val/layer/POW/level: "
            //    << " " << IDX
            //    << "/" << key
            //    << "/" << layer
            //    << "/" << POW<layer>::value
            //    << "/" << level
            //    << "/" << left_idx
            //    << "/" << right_idx
            //    << std::endl;
            if ( n < key )
                return V<T, Ts...>::impl<N, level-1>::template search_impl<left_idx>(n);
            else
                return V<T, Ts...>::impl<N, level-1>::template search_impl<right_idx>(n);       
        }

    };

    template<size_t N>
    struct impl<N,1>
    {
        template<size_t IDX>
        inline static char search_impl(size_t n)
        {
            static const int key   = element<IDX, T, Ts...>::type::key;
            static const char value1 = element<IDX-1, T, Ts...>::type::value;
            static const char value2 = element<IDX, T, Ts...>::type::value;
            if ( n < key )
            {
                //std::cout << " *" << value1 << ' '  << IDX << std::endl;
                return value1;
            } else {
                //std::cout << " *" << value2 << ' '  << IDX << std::endl;
                return value2;
            }
        }
    };

    static void print()
    {
        std::cout << typeid(T).name() << ' ' << T::key << ' ' << (char)T::value << std::endl;
        V<Ts...>::print();
    }
    static char search(size_t n)
    {
        static const size_t level = LOG<sizeof...(Ts)+1>::value;
        static const size_t N = sizeof...(Ts);
        static const int height = LOG<N>::value;
        static const size_t root_idx = N / 2 + 1;
        static const int key = element<root_idx, T, Ts...>::type::key;
        //std::cout << "Level:" << level << " of:" << height << std::endl
        //    << "IDX/val: "
        //    << " " << root_idx
        //    << "/" << input[root_idx]
        //    << std::endl;

        static const size_t left_idx  = root_idx - ( N / POW<2>::value + 1);
        static const size_t right_idx = root_idx + ( N / POW<2>::value + 1);

        if( n < key)
            return V<T, Ts...>::impl<N, level-1>::template search_impl<left_idx>(n);
        else
            return V<T, Ts...>::impl<N, level-1>::template search_impl<right_idx>(n);
    }
};

template<>
struct V<>
{
    static void print()
    {}
};

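int main(int argc, char *argv[])
{
    // argv[1] is required: the key to look up
    if (argc < 2) {
        std::cerr << "usage: " << argv[0] << " <key>" << std::endl;
        return 1;
    }
    int i = std::stoi(argv[1]);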

    typedef V<
    Pair<  0x1,'a'>,
    Pair< 0x11,'b'>,
    Pair< 0x21,'c'>,
    Pair< 0x31,'d'>,
    Pair< 0x41,'e'>,
    Pair< 0x51,'f'>,
    Pair< 0x61,'g'>,
    Pair< 0x71,'h'>,
    Pair< 0x81,'i'>,
    Pair< 0x91,'j'>,
    Pair<0x101,'k'>,
    Pair<0x111,'l'>,
    Pair<0x121,'m'>,
    Pair<0x131,'n'>,
    Pair<0x141,'o'>
    > TV;

    std::cout << (char)TV::search(i) << std::endl;

    return 0;
}

So this is it. My goal was to "force" the compiler to put all the constants into the code, so that nothing is kept in the data segment. The resulting code inlines all the search_impl<*> methods together, and the result contains only "cmp" and "jae" instructions. But it looks like a reasonable compiler will do this anyway if the array to be searched is defined as static const.
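For comparison, here is a sketch of the plain-array approach this answer ends on, assuming C++14: the same 15 pairs as `TV` kept in a `static constexpr` array and searched by an ordinary upper-bound loop. The names `Entry`, `entries` and `search_table` are hypothetical. Like `TV::search`, it returns the first value for inputs below the first key. Whether the optimizer turns a runtime call into a chain of `cmp`/`jae` instructions depends on the compiler and flags, so it is worth inspecting the generated assembly rather than assuming it.

```cpp
#include <cstddef>

// The same <key, value> pairs as TV above, sorted by key.
struct Entry { int key; char value; };

static constexpr Entry entries[] = {
    {0x1, 'a'},   {0x11, 'b'},  {0x21, 'c'},  {0x31, 'd'},  {0x41, 'e'},
    {0x51, 'f'},  {0x61, 'g'},  {0x71, 'h'},  {0x81, 'i'},  {0x91, 'j'},
    {0x101, 'k'}, {0x111, 'l'}, {0x121, 'm'}, {0x131, 'n'}, {0x141, 'o'}};

constexpr std::size_t entry_count = sizeof(entries) / sizeof(entries[0]);

// Value of the largest key not larger than n. As with TV::search,
// inputs below the first key map to the first value.
constexpr char search_table(int n)
{
    std::size_t lo = 0, hi = entry_count; // upper-bound search over [lo, hi)
    while (lo < hi) {
        std::size_t mid = lo + (hi - lo) / 2;
        if (n < entries[mid].key)
            hi = mid;
        else
            lo = mid + 1;
    }
    return entries[lo == 0 ? 0 : lo - 1].value;
}

// Because the table is constexpr, results can also be checked at compile time.
static_assert(search_table(0x15) == 'b', "0x15 lies in [0x11, 0x21)");
static_assert(search_table(0) == 'a', "below the first key");
static_assert(search_table(0x200) == 'o', "above the last key");
```

At runtime it would be called as `search_table(std::stoi(argv[1]))`, replacing `TV::search(i)`; the table's size no longer has to be reflected in template recursion depth.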

ibre5041 answered Mar 19 '23 13:03