Extract scalar value from SSE vector

Tags: c, x86, simd, sse

I have a piece of code that checks, in SIMD fashion, whether array elements are greater than a value:

void sse(uint *dst, size_t N)
{
    const __m128i condition = _mm_set1_epi32(2);

    for (uint i = 0; i < N; i += 4)
    {
        __m128i v = _mm_load_si128((__m128i *)&dst[i]);
        __m128i cmp = _mm_cmpgt_epi32(v, condition);
        v = _mm_and_si128(v, cmp);
        _mm_store_si128((__m128i *)&dst[i], v);
    }
}

Now, after the comparison and before ANDing the elements with _mm_and_si128, I want to count the elements which passed the condition, i.e. those whose mask is set, and store the sum in an int variable. How can I do that in SIMD? For instance, if only two out of four elements passed the condition, the int variable should be 2.

Nik Kovac asked Mar 23 '23 13:03


1 Answer

Typically you would keep a vector count throughout the loop and then just sum the elements of the vector when the loop terminates, e.g.

#include <stddef.h>
#include <stdint.h>
#include <emmintrin.h>

uint32_t sse(uint32_t *dst, const size_t N)   // dst is written to, so it must not be const
{
    const __m128i condition = _mm_set1_epi32(2);
    __m128i vcount = _mm_set1_epi32(0);
    uint32_t count = 0;

    for (size_t i = 0; i < N; i += 4)
    {
        __m128i v = _mm_load_si128((__m128i *)&dst[i]);
        __m128i vcmp = _mm_cmpgt_epi32(v, condition);
        v = _mm_and_si128(v, vcmp);
        _mm_store_si128((__m128i *)&dst[i], v);
        vcount = _mm_add_epi32(vcount, vcmp); // accumulate (negative) counts
    }
    // ... sum vcount here and store in count (see below) ...
    return count;
}

Note that we are treating each mask element as an int, i.e. 0 or -1, and so we are accumulating a sum which is the negative of the actual sum.
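To make the 0 / -1 behaviour concrete, here is a minimal count-only sketch that subtracts the mask instead of adding it, so the accumulator stays positive and no negation is needed at the end. The helper name count_gt, its threshold parameter, the unaligned loads and the assumption that n is a multiple of 4 are all illustrative choices, not part of the code above.

```c
#include <emmintrin.h>  /* SSE2 */
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: counts elements greater than threshold (signed
   compare). Assumes n is a multiple of 4. Uses unaligned loads, so src
   needs no special alignment. */
static uint32_t count_gt(const int32_t *src, size_t n, int32_t threshold)
{
    const __m128i vthr = _mm_set1_epi32(threshold);
    __m128i vcount = _mm_setzero_si128();

    for (size_t i = 0; i < n; i += 4)
    {
        __m128i v = _mm_loadu_si128((const __m128i *)&src[i]);
        __m128i vcmp = _mm_cmpgt_epi32(v, vthr);   /* each lane is 0 or -1 */
        vcount = _mm_sub_epi32(vcount, vcmp);      /* x - (-1) == x + 1 */
    }

    uint32_t lanes[4];
    _mm_storeu_si128((__m128i *)lanes, vcount);    /* spill the four partial counts */
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];
}
```

Spilling the accumulator to a small array is the simplest possible final sum; since it runs once per call, its cost is negligible for large n.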

Efficiency of the final vcount summation is not normally too important as it is performed only once for the entire loop, so provided N is reasonably large it does not matter how many instructions are required (within reason).

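There are several ways of handling the final summation, e.g. with plain SSE2 you can sum the four elements using shuffles and adds, or you can use _mm_hadd_epi32 (SSSE3) to calculate a horizontal sum on the vector and then extract the sum as a scalar. (Note that _mm_movemask_epi8 only reports which bytes have their top bit set, so it can count the passing elements of a single comparison mask but cannot sum the accumulated vcount.) E.g.

SSE2:

#include <emmintrin.h>

vcount = _mm_add_epi32(vcount,                  // add upper half to lower half
             _mm_shuffle_epi32(vcount, _MM_SHUFFLE(1, 0, 3, 2)));
vcount = _mm_add_epi32(vcount,                  // add adjacent elements
             _mm_shuffle_epi32(vcount, _MM_SHUFFLE(2, 3, 0, 1)));
count = - _mm_cvtsi128_si32(vcount);            // extract (and negate) sum to
                                                // get total (positive) count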
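A mask-based count is a better fit for a single comparison result than for an accumulated vector. As a hedged sketch, the helper below (the name count_mask_lanes is hypothetical) takes one 32-bit compare mask such as vcmp and returns how many of its four lanes are set, using _mm_movemask_ps to collect the sign bit of each lane:

```c
#include <emmintrin.h>  /* SSE2; also pulls in the SSE header for _mm_movemask_ps */
#include <stdint.h>

/* Hypothetical helper: number of lanes that are set (all ones) in a
   32-bit compare mask such as the result of _mm_cmpgt_epi32. */
static uint32_t count_mask_lanes(__m128i vcmp)
{
    /* _mm_movemask_ps gathers the sign bit of each 32-bit lane into bits 0..3 */
    int m = _mm_movemask_ps(_mm_castsi128_ps(vcmp));

    /* population count of the 4-bit mask */
    return (uint32_t)((m & 1) + ((m >> 1) & 1) + ((m >> 2) & 1) + ((m >> 3) & 1));
}
```

Calling this once per loop iteration and adding the result to a scalar counter works, but it moves data from the vector unit to a general-purpose register on every iteration, which is why accumulating in a vector is usually preferred.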

SSSE3:

#include <tmmintrin.h>

vcount = _mm_hadd_epi32(vcount, vcount);        // horizontal sum of 4 elements
vcount = _mm_hadd_epi32(vcount, vcount);
count = - (((uint32_t)_mm_extract_epi16(vcount, 1) << 16) // high 16 bits of element 0
          | (uint32_t)_mm_extract_epi16(vcount, 0));      // low 16 bits; negate to get (positive) count

SSE4.1:

#include <smmintrin.h>

vcount = _mm_hadd_epi32(vcount, vcount);        // horizontal sum of 4 elements
vcount = _mm_hadd_epi32(vcount, vcount);
count = - _mm_extract_epi32(vcount, 0);         // extract (and negate) sum to
                                                // get total (positive) count

Here is a complete working version with test harness for the SSE4.1 version:

#include <stdio.h>
#include <stdint.h>
#include <smmintrin.h>

uint32_t sse(uint32_t *dst, const size_t N)   // dst is written to, so it must not be const
{
    const __m128i condition = _mm_set1_epi32(2);
    __m128i vcount = _mm_set1_epi32(0);
    uint32_t count = 0;

    for (size_t i = 0; i < N; i += 4)
    {
        __m128i v = _mm_load_si128((__m128i *)&dst[i]);
        __m128i vcmp = _mm_cmpgt_epi32(v, condition);
        v = _mm_and_si128(v, vcmp);
        _mm_store_si128((__m128i *)&dst[i], v);
        vcount = _mm_add_epi32(vcount, vcmp); // accumulate (negative) counts
    }

    vcount = _mm_hadd_epi32(vcount, vcount);    // horizontal sum of 4 elements
    vcount = _mm_hadd_epi32(vcount, vcount);
    count = - _mm_extract_epi32(vcount, 0);     // extract (and negate) sum to
                                                // get total (positive) count

    return count;
}

int main(void)
{
    uint32_t a[4] __attribute__ ((aligned(16))) = { 1, 2, 3, 4 };
    uint32_t count;

    count = sse(a, 4);

    printf("a = %u %u %u %u \n", a[0], a[1], a[2], a[3]);
    printf("count = %u\n", count);

    return 0;
}

$ gcc -Wall -std=c99 -msse4 sse_count.c -o sse_count
$ ./sse_count
a = 0 0 3 4 
count = 2
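The harness above uses a 16-byte aligned array whose length is a multiple of 4. As a rough sketch of how the same idea might be extended to arbitrary input, the variant below uses unaligned loads and stores and finishes any leftover elements with a scalar loop. The name sse_any is hypothetical, and the scalar tail deliberately uses a signed comparison to match _mm_cmpgt_epi32, which treats the elements as signed 32-bit integers.

```c
#include <emmintrin.h>  /* SSE2 */
#include <stddef.h>
#include <stdint.h>

/* Hypothetical variant: zeroes every element that is not > 2 and returns
   how many elements were kept. Works for any n and any alignment. */
static uint32_t sse_any(uint32_t *dst, size_t n)
{
    const __m128i condition = _mm_set1_epi32(2);
    __m128i vcount = _mm_setzero_si128();
    size_t i = 0;

    for (; i + 4 <= n; i += 4)                     /* full vectors only */
    {
        __m128i v = _mm_loadu_si128((const __m128i *)&dst[i]);
        __m128i vcmp = _mm_cmpgt_epi32(v, condition);
        _mm_storeu_si128((__m128i *)&dst[i], _mm_and_si128(v, vcmp));
        vcount = _mm_sub_epi32(vcount, vcmp);      /* subtract -1 to count up */
    }

    /* horizontal sum of the four partial counts using SSE2 shuffles */
    vcount = _mm_add_epi32(vcount, _mm_shuffle_epi32(vcount, _MM_SHUFFLE(1, 0, 3, 2)));
    vcount = _mm_add_epi32(vcount, _mm_shuffle_epi32(vcount, _MM_SHUFFLE(2, 3, 0, 1)));
    uint32_t count = (uint32_t)_mm_cvtsi128_si32(vcount);

    for (; i < n; i++)                             /* scalar tail, 0 to 3 elements */
    {
        if ((int32_t)dst[i] > 2)                   /* signed, like _mm_cmpgt_epi32 */
            count++;
        else
            dst[i] = 0;
    }
    return count;
}
```

Unaligned loads are generally no slower than aligned ones on recent x86 processors when the data happens to be aligned, so this form is a reasonable default when the caller's alignment is not known.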
Paul R answered Apr 02 '23 15:04