
Which is the most efficient way to extract an arbitrary range of bits from a contiguous sequence of words?

Suppose we have a std::vector, or any other sequence container (sometimes it will be a deque), which stores uint64_t elements.

Now, let's see this vector as a sequence of size() * 64 contiguous bits. I need to find the word formed by the bits in a given [begin, end) range, given that end - begin <= 64 so it fits in a word.
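
To pin down the intended semantics, here is a minimal and deliberately slow reference sketch. It assumes that bit i of the sequence is bit (i % 64) of word (i / 64), least significant bit first, which matches the shifts used in the code further down. The name get_reference is hypothetical and only meant for checking faster versions against it.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference (slow) extraction: copies the bits of [begin, end) one at a time.
// Assumption: bit i of the sequence is bit (i % 64) of word (i / 64), LSB first.
inline std::uint64_t get_reference(std::vector<std::uint64_t> const& words,
                                   std::size_t begin, std::size_t end)
{
    if (begin >= end) return 0;            // empty range, including begin > end
    assert(end - begin <= 64);             // the result must fit in one word
    assert(end <= words.size() * 64);      // the range must lie inside the container

    std::uint64_t result = 0;
    for (std::size_t i = begin; i < end; ++i) {
        std::uint64_t bit = (words[i / 64] >> (i % 64)) & 1u;
        result |= bit << (i - begin);      // bit 'begin' lands in bit 0 of the result
    }
    return result;
}
```

For example, with words {~0, 0, 1}, the range [60, 68) yields 0xF: four set bits from the top of the first word followed by four clear bits from the bottom of the second.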

The solution I have right now finds the two words whose parts will form the result, then masks and combines them separately. Since I need this to be as efficient as possible, I've tried to code everything without any if branches, to avoid branch mispredictions. For example, the code takes the same path whether the entire range fits into one word or spans two words. To do this I needed to write the shiftl and shiftr functions, which do nothing but shift a word by the specified amount, like the >> and << operators, but gracefully handle the case where n is greater than or equal to 64, which would otherwise be undefined behavior.

Another point is that the get() function, as coded now, also works for ranges that are empty in the mathematical sense, i.e. not only when begin == end but also when begin > end. This is required by the main algorithm that calls this function. Again, I've tried to do this without simply branching and returning zero in that case.

However, looking also at the assembly code, all this seems far too complex for such a seemingly simple task. This code runs in a performance-critical algorithm, which is running a bit too slowly. valgrind told us this function is called 230 million times and accounts for 40% of the total execution time, so I really need to make it faster.

So can you help me find a simpler and/or more efficient way to accomplish this task? I don't care too much about portability. Solutions using x86 SIMD intrinsics (SSE3/4/AVX etc.) or compiler builtins are OK, as long as I can compile them with both g++ and clang.

My current code is included below:

#include <cassert>
#include <cstddef>
#include <cstdint>

using word_type = uint64_t;
const size_t W = 64;

// Shift right, but without being undefined behaviour if n >= 64
word_type shiftr(word_type val, size_t n)
{
    uint64_t good = n < W;

    return good * (val >> (n * good));
}

// Shift left, but without being undefined behaviour if n >= 64
word_type shiftl(word_type val, size_t n)
{
    uint64_t good = n < W;

    return good * (val << (n * good));
}

// Mask the word preserving only the lower n bits.
word_type lowbits(word_type val, size_t n)
{
    word_type mask = shiftr(word_type(-1), W - n);

    return val & mask;
}

// Struct for return values of locate()
struct range_location_t {
    size_t lindex; // The word where is located the 'begin' position
    size_t hindex; // The word where is located the 'end' position
    size_t lbegin; // The position of 'begin' into its word
    size_t llen;   // The length of the lower part of the word
    size_t hlen;   // The length of the higher part of the word
};

// Locate the one or two words that will make up the result
range_location_t locate(size_t begin, size_t end)
{
    size_t lindex = begin / W;
    size_t hindex = end / W;
    size_t lbegin = begin % W;
    size_t hend   = end % W;

    size_t len = (end - begin) * size_t(begin <= end);
    size_t hlen = hend * (hindex > lindex);
    size_t llen = len - hlen;

    return { lindex, hindex, lbegin, llen, hlen };
}

// Main function.
template<typename Container>
word_type get(Container const&container, size_t begin, size_t end)
{
    assert(begin < container.size() * W);
    assert(end <= container.size() * W);

    range_location_t loc = locate(begin, end);

    word_type low = lowbits(container[loc.lindex] >> loc.lbegin, loc.llen);

    word_type high = shiftl(lowbits(container[loc.hindex], loc.hlen), loc.llen);

    return high | low;
}

Thank you very much.

gigabytes asked Mar 17 '23 13:03


1 Answer

As announced in the chat, here is a refined answer. It has three parts, each of them followed by a description of that part.

The 1st part, get.h, is my solution, but generalized and with one correction.

The 2nd part, got.h, is the original algorithm as posted in the question, generalized as well to work with any STL container of any unsigned type.

The 3rd part, main.cpp, contains unit tests which verify the correctness and measure performance.

#include <cstddef>

using std::size_t;

template < typename C >
typename C::value_type get ( C const &container, size_t lo, size_t hi )
{

   typedef typename C::value_type item; // a container entry
   static unsigned const bits = (unsigned)sizeof(item)*8u; // bits in an item
   static size_t const mask = ~(size_t)0u/bits*bits; // huge multiple of bits

   // everything above has been computed at compile time. Now do some work:

   size_t lo_adr = (lo       ) / bits; // the index in the container of ...
   size_t hi_adr = (hi-(hi>0)) / bits; // ... the lower or higher item needed

   // we read container[hi_adr] first and possibly delete the highest bits:

   unsigned hi_shift = (unsigned)(mask-hi)%bits;
   item hi_val = container[hi_adr] << hi_shift >> hi_shift;

   // if all bits are in the same item, we delete the lower bits and are done:

   unsigned lo_shift = (unsigned)lo%bits;
   if ( hi_adr <= lo_adr ) return (hi_val>>lo_shift) * (lo<hi);

   // else we have to read the lower item as well, and combine both

   return ( hi_val << (bits-lo_shift) | container[lo_adr] >> lo_shift );

}

The 1st part, get.h above, is my original solution, generalized to work with any STL container of an unsigned integer type. Thus you can use and test it with 32-bit or 128-bit integers as well. I still use unsigned for very small numbers, but you may just as well replace it with size_t. The algorithm is nearly unchanged, with one small correction: if lo was the total number of bits in the container, my previous get() would access an item just past the end of the container. This is fixed now.
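
The index arithmetic in get.h can be checked in isolation. The sketch below copies the two expressions into standalone helpers; the names hi_shift_of and hi_adr_of are hypothetical and not part of get.h. Because mask is a multiple of bits, (mask - hi) % bits is the number of bits above position hi in its word, and subtracting (hi > 0) before dividing makes hi_adr the index of the word holding the last bit of the range, so it never points past the container.

```cpp
#include <cassert>
#include <cstddef>

// Number of high bits to clear in the word that holds bit hi-1,
// computed as in get.h: (mask - hi) % bits, with mask a large multiple of bits.
inline unsigned hi_shift_of(std::size_t hi, unsigned bits)
{
    assert(bits > 0);
    std::size_t const mask = ~static_cast<std::size_t>(0) / bits * bits;
    return static_cast<unsigned>((mask - hi) % bits);
}

// Index of the word that holds the last bit of [lo, hi), as in get.h.
// For hi == 0 the subtraction is skipped, so the index stays at 0.
inline std::size_t hi_adr_of(std::size_t hi, unsigned bits)
{
    assert(bits > 0);
    return (hi - (hi > 0)) / bits;
}
```

With 64-bit items, hi = 65 gives index 1 and a shift of 63 (only the lowest bit of that word is kept), while hi = 64 gives index 0 and a shift of 0 (the whole first word is kept).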

#include <cstddef>

using std::size_t;

// Shift right, but without being undefined behaviour if n >= the bit width of val_type
template < typename val_type >
val_type shiftr(val_type val, size_t n)
{
   val_type good = n < sizeof(val_type)*8;
   return good * (val >> (n * good));
}

// Shift left, but without being undefined behaviour if n >= the bit width of val_type
template < typename val_type >
val_type shiftl(val_type val, size_t n)
{
   val_type good = n < sizeof(val_type)*8;
   return good * (val << (n * good));
}

// Mask the word preserving only the lower n bits.
template < typename val_type >
val_type lowbits(val_type val, size_t n)
{
    val_type mask = shiftr<val_type>((val_type)(-1), sizeof(val_type)*8 - n);
    return val & mask;
}

// Struct for return values of locate()
struct range_location_t {
   size_t lindex; // The word where is located the 'begin' position
   size_t hindex; // The word where is located the 'end' position
   size_t lbegin; // The position of 'begin' into its word
   size_t llen;   // The length of the lower part of the word
   size_t hlen;   // The length of the higher part of the word
};

// Locate the one or two words that will make up the result
template < typename val_type >
range_location_t locate(size_t begin, size_t end)
{
   size_t lindex = begin / (sizeof(val_type)*8);
   size_t hindex = end / (sizeof(val_type)*8);
   size_t lbegin = begin % (sizeof(val_type)*8);
   size_t hend   = end % (sizeof(val_type)*8);

   size_t len = (end - begin) * size_t(begin <= end);
   size_t hlen = hend * (hindex > lindex);
   size_t llen = len - hlen;

   range_location_t l = { lindex, hindex, lbegin, llen, hlen };
   return l;
}

// Main function.
template < typename C >
typename C::value_type got ( C const&container, size_t begin, size_t end )
{
   typedef typename C::value_type val_type;
   range_location_t loc = locate<val_type>(begin, end);
   val_type low = lowbits<val_type>(container[loc.lindex] >> loc.lbegin, loc.llen);
   val_type high = shiftl<val_type>(lowbits<val_type>(container[loc.hindex], loc.hlen), loc.llen);
   return high | low;
}

This 2nd part, got.h above, is the original algorithm in the question, generalized as well to accept any STL containers of any unsigned integer types. Like get.h, this version does not use any definitions except the single template parameter that defines the container type, thus it can easily be tested for other item sizes or container types.

#include <vector>
#include <cstddef>
#include <stdint.h>
#include <stdio.h>
#include <sys/time.h>
#include <sys/resource.h>
#include "get.h"
#include "got.h"

template < typename Container > class Test {

   typedef typename Container::value_type val_type;
   typedef val_type (*fun_type) ( Container const &, size_t, size_t );
   typedef void (Test::*fun_test) ( unsigned, unsigned );
   static unsigned const total_bits = 256; // number of bits in the container
   static unsigned const entry_bits = (unsigned)sizeof(val_type)*8u;

   Container _container;
   fun_type _function;
   bool _failed;

   void get_value ( unsigned lo, unsigned hi ) {
      _function(_container,lo,hi); // we call this several times ...
      _function(_container,lo,hi); // ... because we measure ...
      _function(_container,lo,hi); // ... the performance ...
      _function(_container,lo,hi); // ... of _function, ....
      _function(_container,lo,hi); // ... not the performance ...
      _function(_container,lo,hi); // ... of get_value and ...
      _function(_container,lo,hi); // ... of the loop that ...
      _function(_container,lo,hi); // ... calls get_value.
   }

   void verify ( unsigned lo, unsigned hi ) {
      val_type value = _function(_container,lo,hi);
      if ( lo < hi ) {
         for ( unsigned i=lo; i<hi; i++ ) {
            val_type val = _container[i/entry_bits] >> i%entry_bits & 1u;
            if ( val != (value&1u) ) {
               printf("lo=%u hi=%u [%u] isn't %u\n",lo,hi,i,(unsigned)val);
               _failed = true;
            }
            value >>= 1u;
         }
      }
      if ( value ) {
         printf("lo=%u hi=%u value contains high bits set to 1\n",lo,hi);
         _failed = true;
      }
   }

   void run ( fun_test fun ) {
      for ( unsigned lo=0; lo<total_bits; lo++ ) {
         unsigned h0 = 0;
         if ( lo > entry_bits ) h0 = lo - (entry_bits+1);
         unsigned h1 = lo+entry_bits; // longest range that still fits into one entry
         if ( h1 > total_bits ) h1 = total_bits;
         for ( unsigned hi=h0; hi<=h1; hi++ ) {
            (this->*fun)(lo,hi);
         }
      }
   }

   static uint64_t time_used ( ) {
      struct rusage ru;
      getrusage(RUSAGE_THREAD,&ru);
      struct timeval t = ru.ru_utime;
      return (uint64_t) t.tv_sec*1000 + t.tv_usec/1000;
   }

public:

   Test ( fun_type function ): _function(function), _failed() {
      val_type entry;
      unsigned index = 0; // position in the whole bit array
      unsigned value = 0; // last value assigned to a bit
      static char const entropy[] = "The quick brown Fox jumps over the lazy Dog";
      do {
         if ( ! (index%entry_bits) ) entry = 0;
         entry <<= 1;
         entry |= value ^= 1u & entropy[index/7%sizeof(entropy)] >> index%7;
         ++index;
         if ( ! (index%entry_bits) ) _container.push_back(entry);
      } while ( index < total_bits );
   }

   bool correctness() {
      _failed = false;
      run(&Test::verify);
      return !_failed;
   }

   void performance() {
      uint64_t t1 = time_used();
      for ( unsigned i=0; i<999; i++ ) run(&Test::get_value);
      uint64_t t2 = time_used();
      printf("used %u ms\n",(unsigned)(t2-t1));
   }

   void operator() ( char const * name ) {
      printf("testing %s\n",name);
      correctness();
      performance();
   }

};

int main()
{
   typedef typename std::vector<uint64_t> Container;
   Test<Container> test(get<Container>); test("get");
   Test<Container> tost(got<Container>); tost("got");
}

The 3rd part, main.cpp above, contains a class of unit tests and applies them to get.h and got.h, that is, to my solution and to the slightly modified original code from the question. The unit tests verify correctness and measure speed. To verify correctness, they create a container of 256 bits, fill it with some data, read every bit section of up to as many bits as fit into a container entry, plus lots of pathological cases, and check each result. To measure speed, they read the same sections again many times and report the thread's time used in user space.
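
The timing loop can also be sketched without getrusage(RUSAGE_THREAD), which is Linux-specific. The variant below is an assumption-laden alternative, not what main.cpp does: it measures wall-clock time with std::chrono::steady_clock instead of the thread's user time, and it XORs every result into a caller-supplied sink so the compiler cannot drop the calls as dead code. The helper name time_sweeps is hypothetical, and the container is fixed to std::vector<uint64_t> for brevity.

```cpp
#include <cassert>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper: runs `rounds` sweeps over all ranges [lo, hi) with
// hi - lo <= 64 and returns the elapsed wall-clock time in milliseconds.
// Every result is XORed into `sink` so the calls have an observable effect.
template <typename F>
long long time_sweeps(F extract, std::vector<std::uint64_t> const& words,
                      unsigned rounds, std::uint64_t& sink)
{
    assert(!words.empty());
    std::size_t const total_bits = words.size() * 64;

    auto const start = std::chrono::steady_clock::now();
    for (unsigned r = 0; r < rounds; ++r)
        for (std::size_t lo = 0; lo < total_bits; ++lo)
            for (std::size_t hi = lo; hi <= lo + 64 && hi <= total_bits; ++hi)
                sink ^= extract(words, lo, hi);
    auto const stop = std::chrono::steady_clock::now();

    return std::chrono::duration_cast<std::chrono::milliseconds>(stop - start).count();
}
```

Any callable with the signature (container, lo, hi) can be passed as extract, for example a lambda that forwards to get or got. Wall-clock time is more sensitive to other load on the machine than the thread's user time, so several runs are advisable before comparing numbers.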

Hans Klünder answered Apr 26 '23 18:04