Comparing two vector<bool> with SSE

Tags: c++, x86, simd, sse

I have two vector<bool> A and B.

I want to compare them and count the number of elements that are equal:

For example:

A = {0,1,0,1}
B = {0,0,1,1}

Result will be equal to 2.

I can use _mm_cmpeq_epi8, but it only compares 16 elements at a time (i.e. I would have to convert each 0 and 1 to a char and then do the comparison). Is it possible to compare 128 elements at a time with SSE (or other SIMD instructions)?

user1436187 asked Dec 13 '15 23:12


2 Answers

If you can either assume that vector<bool> is using contiguous byte-sized elements for storage, or if you can consider using something like vector<uint8_t> instead, then this example should give you a good starting point:

static size_t count_equal(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());         // vectors must be same size

    const size_t n = vec1.size();
    const size_t max_block_size = 255 * 16;     // max block size before possible overflow

    __m128i vcount = _mm_setzero_si128();
    size_t i, count = 0;

    for (i = 0; i + 16 <= n; )                  // for each block
    {
        size_t m = std::min(n, i + max_block_size);

        for ( ; i + 16 <= m; i += 16)           // for each vector in block
        {
            __m128i v1 = _mm_loadu_si128((__m128i *)&vec1[i]);
            __m128i v2 = _mm_loadu_si128((__m128i *)&vec2[i]);
            __m128i vcmp = _mm_cmpeq_epi8(v1, v2);
            vcount = _mm_sub_epi8(vcount, vcmp);
        }
        vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());   // sum the 16 byte counters
        count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);   // update count from current block
        vcount = _mm_setzero_si128();           // reset the counters for the next block
    }
    for ( ; i < n; ++i)                         // deal with any remaining partial vector
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}
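
To make the counting trick in the inner loop easier to follow, here is a minimal sketch of the same steps applied to a single pair of 16-byte blocks. The helper name count_equal_16 is made up for illustration and is not part of the answer's code. _mm_cmpeq_epi8 yields 0xFF (i.e. -1) in every byte lane where the inputs match, so subtracting that mask adds 1 to the matching lanes, and _mm_sad_epu8 against zero then sums the byte lanes into two 64-bit halves. Each lane grows by at most 1 per iteration, which is why the full function caps a block at 255 * 16 elements before flushing the byte counters.

```cpp
#include <cstddef>
#include <cstdint>
#include <emmintrin.h>  // SSE2

// Hypothetical helper: counts equal bytes in one pair of 16-byte blocks.
static std::size_t count_equal_16(const std::uint8_t *a, const std::uint8_t *b)
{
    __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i *>(a));
    __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(b));
    __m128i vcmp = _mm_cmpeq_epi8(va, vb);                      // 0xFF where equal, 0x00 elsewhere
    __m128i vcount = _mm_sub_epi8(_mm_setzero_si128(), vcmp);   // 0 - (-1) = 1 in each equal lane
    __m128i vsum = _mm_sad_epu8(vcount, _mm_setzero_si128());   // sum 8 lanes into each 64-bit half
    return static_cast<std::size_t>(_mm_extract_epi16(vsum, 0) + _mm_extract_epi16(vsum, 4));
}
```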

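Note that this uses vector<uint8_t>. If you really have to use vector<bool>, and can guarantee that its elements are always contiguous and byte-sized, then you will need to obtain a const uint8_t * (or similar) to its storage somehow. Be aware that vector<bool> has no data() member and that common implementations pack the elements one per bit, so that guarantee is unlikely to hold in practice.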
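
If you are stuck with vector<bool> as input, one portable way to reach the byte-per-element layout is simply to copy the flags out first. The sketch below is only an illustration (to_bytes is a hypothetical helper name), and the copy is itself an O(n) pass, so it may well cost more than the SIMD comparison saves unless the byte vectors are reused.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: copies each flag into its own byte (0 or 1).
static std::vector<std::uint8_t> to_bytes(const std::vector<bool> &bits)
{
    return std::vector<std::uint8_t>(bits.begin(), bits.end());
}
```

With that in place, count_equal(to_bytes(A), to_bytes(B)) would work on any conforming implementation.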

Test harness:

#include <algorithm>      // std::min
#include <cassert>
#include <cstdint>        // uint8_t
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

#include <emmintrin.h>    // SSE2

using std::vector;

static size_t count_equal_ref(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());

    const size_t n = vec1.size();
    size_t i, count = 0;

    for (i = 0 ; i < n; ++i)
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}

static size_t count_equal(const vector<uint8_t> &vec1, const vector<uint8_t> &vec2)
{
    assert(vec1.size() == vec2.size());         // vectors must be same size

    const size_t n = vec1.size();
    const size_t max_block_size = 255 * 16;     // max block size before possible overflow

    __m128i vcount = _mm_setzero_si128();
    size_t i, count = 0;

    for (i = 0; i + 16 <= n; )                  // for each block
    {
        size_t m = std::min(n, i + max_block_size);

        for ( ; i + 16 <= m; i += 16)           // for each vector in block
        {
            __m128i v1 = _mm_loadu_si128((__m128i *)&vec1[i]);
            __m128i v2 = _mm_loadu_si128((__m128i *)&vec2[i]);
            __m128i vcmp = _mm_cmpeq_epi8(v1, v2);
            vcount = _mm_sub_epi8(vcount, vcmp);
        }
        vcount = _mm_sad_epu8(vcount, _mm_setzero_si128());   // sum the 16 byte counters
        count += _mm_extract_epi16(vcount, 0) + _mm_extract_epi16(vcount, 4);   // update count from current block
        vcount = _mm_setzero_si128();           // reset the counters for the next block
    }
    for ( ; i < n; ++i)                         // deal with any remaining partial vector
    {
        count += (vec1[i] == vec2[i]);
    }
    return count;
}

int main(int argc, char * argv[])
{
    size_t n = 100;

    if (argc > 1)
    {
        n = atoi(argv[1]);
    }

    vector<uint8_t> vec1(n);
    vector<uint8_t> vec2(n);

    srand((unsigned int)time(NULL));

    for (size_t i = 0; i < n; ++i)
    {
        vec1[i] = rand() & 1;
        vec2[i] = rand() & 1;
    }

    size_t n_ref = count_equal_ref(vec1, vec2);
    size_t n_test = count_equal(vec1, vec2);

    if (n_ref == n_test)
    {
        std::cout << "PASS" << std::endl;
    }
    else
    {
        std::cout << "FAIL: n_ref = " << n_ref << ", n_test = " << n_test << std::endl;
    }

    return 0;
}

Compile and run:

$ g++ -Wall -msse3 -O3 test.cpp && ./a.out
PASS
Paul R answered Oct 03 '22 19:10


std::vector<bool> is a specialization of std::vector for the type bool. Although the C++ standard does not require it, most implementations make std::vector<bool> space efficient, so that each element occupies a single bit rather than a whole bool.

The behaviour of std::vector<bool> is similar to that of the primary template, except that:

  1. std::vector<bool> does not necessarily store its elements contiguously.
  2. In order to expose its elements (i.e., the individual bits), std::vector<bool> uses a proxy class (i.e., std::vector<bool>::reference). Objects of class std::vector<bool>::reference are returned by value from the subscript operator (i.e., operator[]).

Accordingly, I don't think it's portable to use functions like _mm_cmpeq_epi8, since the storage of a std::vector<bool> is implementation-defined (i.e., not guaranteed to be contiguous).
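
That said, if you control the storage yourself instead of relying on std::vector<bool>, packed bits do let you process 128 flags per SSE operation. The following is a rough sketch under my own assumptions, not something taken from the question: the flags are packed 64 per std::uint64_t with the lowest bit first, both buffers hold at least (n_bits + 63) / 64 words, and any unused padding bits in the last word are identical in both buffers (e.g. zero). The function name count_equal_packed is made up. Only the XOR is done with SSE2 here; the bit counting falls back to std::bitset, so whether this beats a plain 64-bit loop is something you would have to measure.

```cpp
#include <bitset>
#include <cstddef>
#include <cstdint>
#include <vector>
#include <emmintrin.h>  // SSE2

// Assumption: flags are packed by the caller, 64 per word, lowest bit first,
// and padding bits past n_bits are the same in both buffers.
static std::size_t count_equal_packed(const std::vector<std::uint64_t> &a,
                                      const std::vector<std::uint64_t> &b,
                                      std::size_t n_bits)
{
    const std::size_t n_words = (n_bits + 63) / 64;
    std::size_t diff = 0, i = 0;

    for ( ; i + 2 <= n_words; i += 2)           // 128 flags per iteration
    {
        __m128i va = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&a[i]));
        __m128i vb = _mm_loadu_si128(reinterpret_cast<const __m128i *>(&b[i]));
        std::uint64_t x[2];
        _mm_storeu_si128(reinterpret_cast<__m128i *>(x), _mm_xor_si128(va, vb));
        diff += std::bitset<64>(x[0]).count() + std::bitset<64>(x[1]).count();  // set bits = mismatches
    }
    for ( ; i < n_words; ++i)                   // leftover word, if any
    {
        diff += std::bitset<64>(a[i] ^ b[i]).count();
    }
    return n_bits - diff;                       // equal = total - mismatches
}
```

For the question's example, A = {0,1,0,1} packs to the word 0b1010 and B = {0,0,1,1} to 0b1100, so the XOR has two set bits and the result is 4 - 2 = 2.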

A portable alternative is to use regular STL facilities, as in the example below:

std::vector<bool> A = {0,1,0,1};
std::vector<bool> B = {0,0,1,1};
std::vector<bool> C(A.size());
std::transform(A.begin(), A.end(), B.begin(), C.begin(), [](bool const &a, bool const &b) { return a == b;});
std::cout << std::count(C.begin(), C.end(), true) << std::endl;
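
The same count can also be obtained in a single pass, without the temporary vector C. Here is a minimal sketch using std::inner_product, assuming both vectors have the same size; the function name count_equal_stl is just for illustration.

```cpp
#include <cstddef>
#include <functional>
#include <numeric>
#include <vector>

// Assumes a.size() == b.size(); adds 1 for every position where the flags match.
static std::size_t count_equal_stl(const std::vector<bool> &a, const std::vector<bool> &b)
{
    return std::inner_product(a.begin(), a.end(), b.begin(), std::size_t{0},
                              std::plus<std::size_t>(), std::equal_to<bool>());
}
```

Called with the A and B above, this should return 2.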

101010 answered Oct 03 '22 17:10