 

Something faster than std::nth_element

I'm working on a kd-tree implementation and I'm currently using std::nth_element to partition a vector of elements around their median. However, std::nth_element takes 90% of the tree construction time. Can anyone suggest a more efficient alternative?

Thanks in advance

asked May 26 '15 14:05 by plasmacel



2 Answers

Do you really need the nth element, or do you need an element "near" the middle?

There are faster ways to get an element "near" the middle. One example goes roughly like:

function rough_middle(container)
  if container has at most 5 elements, return its median
  divide container into subsequences of length 5
  find median of each subsequence of length 5 ~ O(1) each, O(n) in total
  return rough_middle( { median of each subsequence} ) ~ T(n/5), so T(n) = O(n) + T(n/5) = O(n)
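A direct C++ translation of the pseudocode, as a minimal sketch: it assumes a non-empty std::vector whose elements support operator<, works on copies for clarity, and returns the rough middle value rather than its index. The helper median_of_small is a name made up for this example.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: median of a group of 1..5 values (lower median if even).
template <class T>
T median_of_small(std::vector<T> group) {
    std::sort(group.begin(), group.end());
    return group[(group.size() - 1) / 2];
}

// rough_middle from the pseudocode: split into groups of 5, take each group's
// median, then recurse on the medians (never computing their exact median).
template <class T>
T rough_middle(const std::vector<T>& v) {
    assert(!v.empty());
    if (v.size() <= 5) return median_of_small(v);  // base case
    std::vector<T> medians;
    medians.reserve((v.size() + 4) / 5);
    for (std::size_t i = 0; i < v.size(); i += 5) {
        std::vector<T> group;  // the last group may hold fewer than 5 values
        for (std::size_t j = i; j < std::min(i + 5, v.size()); ++j) group.push_back(v[j]);
        medians.push_back(median_of_small(group));
    }
    return rough_middle(medians);  // recurse on the n/5 medians
}
```

On 125 distinct values (three levels of reduction), at least 26 values end up below the result and at least 26 above it, whatever the input order.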

The result should be something that is roughly in the middle. A real nth element algorithm might use something like the above, and then clean it up afterwards to find the actual nth element.

At n=5, you get the middle.

At n=25, you get the middle of the short sequence middles. It is greater than 2 of the other middles, each of which is greater than 2 elements of its own short sequence, and it is also greater than 2 elements of its own short sequence. So at least 2*3+2=8 elements are less than it (and, symmetrically, at least 8 are greater): it is at least the 9th element and no more than the 17th element, or 32% away from the edge.

At n=125, you get the rough middle of the 25 short sequence middles. This is at least the 9th middle, so there are at least 8*3+2=26 elements less than your rough middle, or 20.8% away from the edge.

At n=625, you get the rough middle of the 125 short sequence middles. This is at least the 27th middle, so there are at least 26*3+2=80 elements less than your rough middle, or 12.8% away from the edge.

At n=5^k, you get the rough middle of the 5^(k-1) short sequence middles. If r(k) is the number of elements guaranteed to be less than the rough middle of a 5^k sequence, then r(1) = 2 and r(k+1) = r(k)*3+2, so r(k) = 3^k - 1 ~ 3^k.

3^k grows slower than 5^k in O-notation.

3^log_5(n)
= e^( ln(3) ln(n)/ln(5) )
= n^(ln(3)/ln(5))
=~ n^0.68

is a very rough estimate of the lower bound on how many elements end up on each side of the rough_middle of a sequence of n elements.
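The guaranteed counts for n = 25, 125, 625, ... can be tabulated with a few lines of C++. The recurrence comes from the counting argument: each of the smaller middles contributes itself plus 2 smaller elements of its short sequence, and the rough middle's own short sequence contributes 2 more. This sketch (function names are hypothetical) iterates that recurrence from r(1) = 2, checks it against the closed form 3^k - 1, and prints n^(ln 3 / ln 5) next to it, which for n = 5^k is exactly 3^k.

```cpp
#include <cassert>
#include <cmath>
#include <cstdio>

// Elements guaranteed to lie strictly below the recursive rough middle of
// n = 5^k elements: r(1) = 2, r(k+1) = 3*r(k) + 2.
long long guaranteed_below(int k) {
    long long r = 2;
    for (int i = 1; i < k; ++i) r = 3 * r + 2;
    return r;
}

// Prints n, r(k), the distance from the edge in percent, and n^(ln3/ln5).
void print_table(int max_k) {
    long long n = 5, three_k = 3;
    for (int k = 1; k <= max_k; ++k, n *= 5, three_k *= 3) {
        long long r = guaranteed_below(k);
        assert(r == three_k - 1);  // closed form r(k) = 3^k - 1
        std::printf("n=%lld  r=%lld  (%.1f%% from edge)  n^(ln3/ln5)=%.1f\n",
                    n, r, 100.0 * r / n,
                    std::pow(static_cast<double>(n), std::log(3.0) / std::log(5.0)));
    }
}
```

For k = 1..4 the printed percentages are 40.0, 32.0, 20.8 and 12.8, shrinking toward 0 as n grows.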

In theory, it may take as many as approx n^0.33 rounds of reduction to get down to a single element, which isn't really that good. (Each partition around the rough middle is only guaranteed to discard about n^0.68 of the n elements, and n / n^0.68 = n^0.32, roughly n^(1/3). It is actually somewhat more, because as n shrinks, the number of elements guaranteed to be discarded per round shrinks with it.)

The nth element solutions I've seen solve this by doing a partition and repair at each level: instead of recursing into rough_middle, you recurse into an exact middle (select the true median of the medians). That median of the medians is then guaranteed to be reasonably close to the actual middle of your sequence (roughly 30% or more of the elements on each side), and you can "find the real middle" relatively quickly (O(n) in O-notation) from there.
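One way to sketch that "recurse into the exact middle" idea is the classic median-of-medians (BFPRT) selection. This version is written for clarity rather than speed: it copies into temporary vectors, uses a three-way partition so duplicate keys cannot stall it, and the name select_kth is hypothetical. Its worst case is O(n), but the constant factor is large, so this copying version will almost certainly be slower than std::nth_element; it mainly illustrates the guarantee.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// Median-of-medians (BFPRT) selection: returns the k-th smallest (0-based)
// element of v. Hypothetical name; written for clarity, not speed.
template <class T>
T select_kth(std::vector<T> v, std::size_t k) {
    assert(k < v.size());
    while (v.size() > 5) {
        // 1. Median of each group of 5 (the last group may be shorter).
        std::vector<T> medians;
        for (std::size_t i = 0; i < v.size(); i += 5) {
            auto first = v.begin() + static_cast<std::ptrdiff_t>(i);
            auto last = v.begin() + static_cast<std::ptrdiff_t>(std::min(i + 5, v.size()));
            std::sort(first, last);
            medians.push_back(*(first + (last - first - 1) / 2));
        }
        // 2. The "repair" step: the exact median of the medians, found recursively.
        T pivot = select_kth(medians, (medians.size() - 1) / 2);
        // 3. Three-way partition around the pivot; keep only the side holding k.
        std::vector<T> less, greater;
        std::size_t equal = 0;
        for (const T& x : v) {
            if (x < pivot) less.push_back(x);
            else if (pivot < x) greater.push_back(x);
            else ++equal;
        }
        if (k < less.size()) {
            v = std::move(less);
        } else if (k < less.size() + equal) {
            return pivot;
        } else {
            k -= less.size() + equal;
            v = std::move(greater);
        }
    }
    std::sort(v.begin(), v.end());  // small base case
    return v[k];
}
```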

Possibly we can optimize this process by doing more accurate rough_middle iterations when there are more elements, but never forcing the result to be the actual middle? The bigger the final n is, the closer to the middle the recursive calls need to land for the end result to be reasonably close to the middle.

But in practice, the probability that your sequence is a really bad one that actually takes n^0.33 steps to partition down to nothing might be really low. Sort of like the quicksort problem: median of 3 elements is usually good enough.
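For a kd-tree, the split often only needs to land near the middle, in the spirit of the median-of-3 remark. A minimal sketch under that assumption (a swapped-in technique, not code from this answer; the names split_near_middle and med3 are made up here): take the median of the first, middle and last elements as the pivot and do a single std::partition pass, which is O(n) with no recursion. It assumes random-access iterators.

```cpp
#include <algorithm>
#include <iterator>

// Median of three values under a strict weak ordering `less`.
template <class T, class Less>
T med3(const T& a, const T& b, const T& c, Less less) {
    if (less(a, b)) {
        if (less(b, c)) return b;   // a < b < c
        return less(a, c) ? c : a;  // a < b, c <= b: the larger of a and c
    }
    if (less(a, c)) return a;       // b <= a < c
    return less(b, c) ? c : b;      // b <= a, c <= a: the larger of b and c
}

// Hypothetical kd-tree helper: splits [first, last) around a median-of-3 pivot
// in one pass. Afterwards no element before the returned iterator compares
// greater than any element at or after it.
template <class It, class Less>
It split_near_middle(It first, It last, Less less) {
    using T = typename std::iterator_traits<It>::value_type;
    auto n = std::distance(first, last);
    if (n < 3) {  // too small to sample three elements: just sort
        std::sort(first, last, less);
        return std::next(first, n / 2);
    }
    T pivot = med3<T>(*first, *std::next(first, n / 2), *std::prev(last), less);
    // One O(n) pass: [first, split) < pivot <= [split, last).
    return std::partition(first, last, [&](const T& x) { return less(x, pivot); });
}
```

A caller splitting on one axis would pass a comparator such as [axis](const Point& a, const Point& b) { return a[axis] < b[axis]; }, where Point is whatever point type the tree uses. The balance is only probabilistic, and with many equal keys one side can come out empty, so a real tree builder would need a fallback for that case (for example std::nth_element).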


A quick stats analysis.

You pick 5 elements at random, and pick the middle one.

The position of the median of a set of 2m+1 random samples from a uniform distribution follows the beta distribution with parameters (m+1, m+1), with some scaling for intervals other than [0,1]. For our 5 samples, m = 2, so that is Beta(3, 3).
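To sanity-check the Beta(3, 3) claim, a small Monte Carlo sketch (the struct and function names are made up for this example) draws 5 uniform samples many times and measures the mean and variance of their median. For Beta(a, b) the variance is ab / ((a+b)^2 (a+b+1)), which for a = b = 3 is 9/252 = 1/28, about 0.036.

```cpp
#include <algorithm>
#include <array>
#include <random>

// Hypothetical result type for this example.
struct Moments {
    double mean;
    double variance;
};

// Monte Carlo estimate of the mean and variance of the median of 5 samples
// drawn uniformly from [0, 1). Beta(3, 3) predicts mean 0.5 and variance 1/28.
Moments median_of_5_moments(int trials, unsigned seed) {
    std::mt19937 gen(seed);
    std::uniform_real_distribution<double> uniform(0.0, 1.0);
    double sum = 0.0, sum_sq = 0.0;
    for (int t = 0; t < trials; ++t) {
        std::array<double, 5> s{};
        for (double& x : s) x = uniform(gen);
        std::nth_element(s.begin(), s.begin() + 2, s.end());  // s[2] is now the median
        sum += s[2];
        sum_sq += s[2] * s[2];
    }
    double mean = sum / trials;
    return {mean, sum_sq / trials - mean * mean};
}
```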

The mean of the median is clearly 1/2. The variance of Beta(a, b) is ab / ((a+b)^2 (a+b+1)), so with a = b = 3 it is:

(3*3) / ( (3+3)^2 (3+3+1) )
= 9 / (36 * 7)
= 1/28 =~ 0.036

Figuring out the next step is beyond my stats. I'll cheat.

If we imagine that taking the median index element from a bunch of items with mean 0.5 and variance 1/28 is as good as averaging their index...

Let n now be the number of elements in our original set.

Scaled to n elements, each short sequence's median has an index with mean n/2 and variance n^2/28. Treating the n/5 medians as independent, the sum of their indexes has an average of n/5 times n/2 = 0.1 * n^2, and a variance of n/5 times n^2/28 = n^3/140.

If we then divide the sum by n/5 to get the average index, the mean is divided by n/5 and the variance by (n/5)^2:

mean = 0.1 * n^2 / (n/5) = n/2
variance = (n^3/140) / (n/5)^2 = 5n/28 =~ 0.18 * n

So a mean of n/2 and a variance of about 0.18 * n, i.e. a standard deviation of about 0.42 * sqrt(n).

Oh, if that was true, that would be awesome. A standard deviation that only grows like sqrt(n) means that, relative to n, the average index of the medians of the short sequences gets ridiculously tightly distributed as n gets large. I guess it makes some sense. Sadly, we aren't quite doing that: we want the distribution of the pseudo-median of the medians of the short sequences, which is almost certainly worse.
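The idealized average-of-the-medians estimate can be contrasted with what the recursive rough middle actually does. This hypothetical simulation (names invented here) applies the recursive median-of-5 reduction to random permutations of 0..n-1, where each value equals its own rank, and reports the standard deviation of the chosen rank as a fraction of n. It makes no claim about exact numbers; it is a way to see how fast the spread shrinks as n grows and to compare that with 1/sqrt(n).

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

// Recursive rough middle (median of each group of 5, then repeat on the
// medians), written as a loop. Assumes v is non-empty.
int rough_middle_value(std::vector<int> v) {
    while (v.size() > 5) {
        std::vector<int> medians;
        for (std::size_t i = 0; i < v.size(); i += 5) {
            std::size_t end = std::min(i + 5, v.size());
            std::sort(v.begin() + static_cast<std::ptrdiff_t>(i),
                      v.begin() + static_cast<std::ptrdiff_t>(end));
            medians.push_back(v[i + (end - i - 1) / 2]);  // group median
        }
        v = std::move(medians);
    }
    std::sort(v.begin(), v.end());
    return v[(v.size() - 1) / 2];
}

// Standard deviation of (rank / n) of the rough middle over random permutations.
double rough_middle_spread(int n, int trials, unsigned seed) {
    std::mt19937 gen(seed);
    std::vector<int> v(static_cast<std::size_t>(n));
    double sum = 0.0, sum_sq = 0.0;
    for (int t = 0; t < trials; ++t) {
        std::iota(v.begin(), v.end(), 0);  // values 0..n-1, so value == rank
        std::shuffle(v.begin(), v.end(), gen);
        double x = static_cast<double>(rough_middle_value(v)) / n;
        sum += x;
        sum_sq += x * x;
    }
    double mean = sum / trials;
    return std::sqrt(sum_sq / trials - mean * mean);
}
```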


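Implementation detail: with only a logarithmic amount of extra memory, we can compute the rough median in a single pass without moving any elements. (We might even be able to do it without the memory overhead!)

We maintain a vector of levels. Each level holds up to 5 element indexes plus a count of how many of its slots are filled (the unfilled slots are the "nothing here" placeholders).

Each level is one successive layer of the recursion.

For each element, we push its index onto the bottom level. If that level becomes full, we take its median, push that onto the next level up (which may cascade further), and clear the level.

At the end, we collapse the partially filled levels from the bottom up: the median of each level is pushed into the level above it, and the median of the top level is the result.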

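#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <utility>
#include <vector>

// One level of the cascade: a count of filled slots plus up to 5 element indexes.
using target = std::pair<std::size_t, std::array<std::size_t, 5>>;
constexpr std::size_t npos = static_cast<std::size_t>(-1);  // "no median" marker

// Adds index i to level t; returns true once the level holds 5 indexes.
bool push( target& t, std::size_t i ) {
  t.second[t.first]=i;
  ++t.first;
  return t.first==5;
}
// Sorts the filled slots by the values they index, returns the (lower) median
// index, and empties the level.
template<class Container>
std::size_t extract_median( Container const& c, target& t ) {
  assert(t.first != 0);
  std::sort( t.second.data(), t.second.data()+t.first, [&c](std::size_t lhs, std::size_t rhs){
    return c[lhs]<c[rhs];
  } );
  std::size_t r = t.second[(t.first-1)/2];
  t.first = 0;
  return r;
}
// Pushes index i onto the bottom level; each full level passes its median up.
template<class Container>
void advance(Container const& c, std::vector<target>& targets, std::size_t i) {
  std::size_t height = 0;
  while(true) {
    if (targets.size() <= height)
      targets.push_back({});
    if (!push(targets[height], i))
      return;
    i = extract_median(c, targets[height]);
    ++height;
  }
}
// Folds the partially filled levels [b, e) bottom-up; npos if all are empty.
template<class Container>
std::size_t collapse(Container const& c, target* b, target* e) {
  if (b==e) return npos;
  std::size_t below = collapse(c, b, e-1);
  target& last = *(e-1);
  if (below!=npos)
    push(last, below);
  if (last.first == 0)
    return npos;
  return extract_median(c, last);
}
// Index of a rough median of c (npos if c is empty).
template<class Container>
std::size_t rough_median_index( Container const& c ) {
  std::vector<target> targets;
  for (std::size_t i = 0; i < c.size(); ++i) {
    advance(c, targets, i);
  }
  return collapse(c, targets.data(), targets.data()+targets.size());
}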

which sketches out how it could work on random access containers.

answered Nov 08 '22 11:11 by Yakk - Adam Nevraumont



If you have more lookups than insertions into the vector, you could consider using a data structure that sorts on insertion, such as std::set (or std::multiset if keys can repeat), and then use std::advance() to get the n'th element in sorted order.
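A sketch of that suggestion, using std::multiset (which, unlike std::set, keeps duplicate keys) and std::next/std::prev, the C++11 convenience wrappers around std::advance. The helper name nth_in_sorted is hypothetical. Insertion is O(log n), but reaching the n'th element is still O(n) per lookup because set iterators are bidirectional, not random access; walking from the closer end only halves that.

```cpp
#include <cstddef>
#include <iterator>
#include <set>

// Hypothetical helper: the n-th smallest key (0-based) of a multiset.
// Assumes n < s.size(). Each std::next/std::prev step visits one node.
template <class T>
const T& nth_in_sorted(const std::multiset<T>& s, std::size_t n) {
    // Walk from whichever end is closer.
    if (n < s.size() / 2)
        return *std::next(s.begin(), static_cast<std::ptrdiff_t>(n));
    return *std::prev(s.end(), static_cast<std::ptrdiff_t>(s.size() - n));
}
```

For a median, call nth_in_sorted(s, s.size() / 2).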

answered Nov 08 '22 10:11 by Frerich Raabe
