The indices of non-zero bytes of an SSE/AVX register

Tags: c++, c, avx, simd, sse

If an SSE/AVX register's value is such that all its bytes are either 0 or 1, is there any way to efficiently get the indices of all non-zero elements?

For example, if the xmm value is | r0=0 | r1=1 | r2=0 | r3=1 | r4=0 | r5=1 | r6=0 |...| r14=0 | r15=1 | the result should be something like (1, 3, 5, ... , 15). The result should be placed in another __m128i variable or a char[16] array.

If it helps, we can assume that the register's value is such that all bytes are either 0 or some constant non-zero value (not necessarily 1).

I am mainly wondering whether there is an instruction for this, or preferably a C/C++ intrinsic, in any of the SSE or AVX instruction sets.

EDIT 1:

As @zx485 correctly observed, the original question was not clear enough. I was looking for any "consecutive" solution.

The example 0 1 0 1 0 1 0 1... above should result in any of the following:

  • If we assume that indices start from 1, then 0 would be a termination byte and the result might be

002 004 006 008 010 012 014 016 000 000 000 000 000 000 000 000

  • If we use a negative byte as the termination byte, the result might be

001 003 005 007 009 011 013 015 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF 0xFF

  • Anything that gives us consecutive bytes which we can interpret as the indices of the non-zero elements in the original value
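A plain scalar reference of the 0xFF-terminated variant can be handy for checking any SIMD version against. The following is only a sketch: the name nonz_index_ref is hypothetical, and it treats any non-zero byte as "set", so it covers both the 0/1 case and the 0/constant case.

```c
#include <stdint.h>

/* Hypothetical scalar reference: store the indices of the non-zero bytes
   of in[0..15] consecutively in out[], pad the remaining bytes with 0xFF
   (termination bytes), and return the number of non-zero bytes. */
static int nonz_index_ref(const uint8_t in[16], uint8_t out[16])
{
    int n = 0;
    for (int i = 0; i < 16; i++)
        if (in[i] != 0)
            out[n++] = (uint8_t)i;   /* next consecutive output slot */
    for (int i = n; i < 16; i++)
        out[i] = 0xFF;               /* termination bytes */
    return n;
}
```

For the 0 1 0 1 ... example this writes 1, 3, 5, ..., 15 followed by eight 0xFF bytes and returns 8.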

EDIT 2:

Indeed, as @harold and @Peter Cordes suggest in the comments on the original post, one possible solution is to create a mask first (e.g. with pmovmskb) and then find the non-zero indices in that mask. But that leads to a loop.
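For comparison, the loop-based mask approach might look like the sketch below. The function name is hypothetical. It compares against zero before pmovmskb because pmovmskb only reads the top bit of each byte, which is clear for a byte value of 1. It relies on the GCC/Clang builtin __builtin_ctz (MSVC would need _BitScanForward instead).

```c
#include <stdint.h>
#include <emmintrin.h>   /* SSE2: _mm_cmpeq_epi8, _mm_movemask_epi8 */

/* Hypothetical loop-based variant: build a 16-bit mask of the non-zero
   bytes, then peel off one set bit per iteration. Indices go to out[]
   consecutively, the remaining bytes are set to 0xFF.
   Returns the number of non-zero bytes. */
static int nonz_index_loop(__m128i x, uint8_t out[16])
{
    __m128i is_zero = _mm_cmpeq_epi8(x, _mm_setzero_si128());     /* 0xFF where byte == 0 */
    unsigned m = ~(unsigned)_mm_movemask_epi8(is_zero) & 0xFFFFu;  /* bit i set <=> byte i != 0 */
    int n = 0;
    while (m != 0) {
        out[n++] = (uint8_t)__builtin_ctz(m);   /* index of the lowest set bit */
        m &= m - 1;                             /* clear the lowest set bit */
    }
    for (int i = n; i < 16; i++)
        out[i] = 0xFF;                          /* termination bytes */
    return n;
}
```

The loop runs once per non-zero byte, so its cost depends on the data, which is exactly what the question would like to avoid.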

asked Feb 28 '16 10:02 by TruLa


2 Answers

Your question was unclear about whether you want the result array to be "compressed". By "compressed" I mean that the results should be consecutive. So, for example, for 0 1 0 1 0 1 0 1... there are two possibilities:

Non-consecutive:

XMM0: 000 001 000 003 000 005 000 007 000 009 000 011 000 013 000 015

Consecutive:

XMM0: 001 003 005 007 009 011 013 015 000 000 000 000 000 000 000 000

One problem of the consecutive approach is: how do you decide if it's index 0 or a termination value?

I'm offering a simple solution to the first, non-consecutive approach, which should be quite fast:

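.data
  align 16                                    ; movdqa and the pandn memory operand need 16-byte alignment
  ddqZeroToFifteen  db 0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15  ; index of each byte
  ddqTestValue      db 0,1,0,1,0,1,0,1,0,1,0,1,0,1,0,1        ; example input
.code
  movdqa xmm0, xmmword ptr [ddqTestValue]     ; load the input
  pxor xmm1, xmm1                             ; zero XMM1
  pcmpeqb xmm0, xmm1                          ; 0xFF in every byte that is zero, 0x00 elsewhere
  pandn xmm0, xmmword ptr [ddqZeroToFifteen]  ; (NOT xmm0) AND indices: index for non-zero bytes, 0 for zero bytes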
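The same three instructions can be written with SSE2 intrinsics in C. This is a sketch with a hypothetical function name; _mm_andnot_si128(a, b) computes (NOT a) AND b, just like pandn.

```c
#include <emmintrin.h>   /* SSE2: pcmpeqb, pandn */

/* Hypothetical intrinsics version of the non-consecutive approach:
   each non-zero byte is replaced by its own index, each zero byte stays 0. */
static __m128i nonz_index_sparse(__m128i x)
{
    const __m128i idx = _mm_setr_epi8(0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15);
    __m128i is_zero   = _mm_cmpeq_epi8(x, _mm_setzero_si128());  /* 0xFF where byte == 0 */
    return _mm_andnot_si128(is_zero, idx);                      /* (~is_zero) & idx */
}
```

As with the assembly version, a non-zero byte 0 and a zero byte both produce 0, so this result alone cannot tell them apart.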

For the sake of completeness: the second, consecutive approach is not covered in this answer.

answered Sep 22 '22 09:09 by zx485


Updated answer: the new solution is slightly more efficient.

You can do this without a loop by using the pext instruction from the Bit Manipulation Instruction Set 2 (BMI2), in combination with a few other SSE instructions:

/*
gcc -O3 -Wall -m64 -mavx2 -march=broadwell ind_nonz_avx.c
*/

#include <stdio.h>
#include <immintrin.h>
#include <stdint.h>

__m128i nonz_index(__m128i x){
   /* Set some constants that will (hopefully) be hoisted out of a loop after inlining. */
   uint64_t  indx_const   = 0xFEDCBA9876543210;                       /* 16 4-bit integers, all possible indices from 0 to 15                                                           */
   __m128i   cntr         = _mm_set_epi8(64,60,56,52,48,44,40,36,32,28,24,20,16,12,8,4);
   __m128i   pshufbcnst   = _mm_set_epi8(0x80,0x80,0x80,0x80,0x80,0x80,0x80,0x80,  0x0E,0x0C,0x0A,0x08,0x06,0x04,0x02,0x00);
   __m128i   cnst0F       = _mm_set1_epi8(0x0F);

   __m128i   msk          = _mm_cmpeq_epi8(x,_mm_setzero_si128());    /* Generate 16x8 bit mask.                                                                                        */
             msk          = _mm_srli_epi64(msk,4);                    /* Pack 16x8 bit mask to 16x4 bit mask.                                                                           */
             msk          = _mm_shuffle_epi8(msk,pshufbcnst);         /* Pack 16x8 bit mask to 16x4 bit mask, continued.                                                                */
   uint64_t  msk64        = ~ _mm_cvtsi128_si64x(msk);                 /* Move to general purpose register and invert 16x4 bit mask.                                                     */

                                                                      /* Compute the termination byte nonzmsk separately.                                                               */
   int64_t   nnz64        = _mm_popcnt_u64(msk64);                    /* Count the nonzero bits in msk64.                                                                               */
   __m128i   nnz          = _mm_set1_epi8(nnz64);                     /* May generate vmovd + vpbroadcastb if AVX2 is enabled.                                                          */
   __m128i   nonzmsk      = _mm_cmpgt_epi8(cntr,nnz);                 /* nonzmsk is a mask of the form 0xFF, 0xFF, ..., 0xFF, 0, 0, ...,0 to mark the output positions without an index */

   uint64_t  indx64       = _pext_u64(indx_const,msk64);              /* parallel bits extract. pext shuffles indx_const such that indx64 contains the nnz64 4-bit indices that we want.*/
   __m128i   indx         = _mm_cvtsi64x_si128(indx64);               /* Use a few integer instructions to unpack 4-bit integers to 8-bit integers.                                     */
   __m128i   indx_024     = indx;                                     /* Even indices.                                                                                                  */
   __m128i   indx_135     = _mm_srli_epi64(indx,4);                   /* Odd indices.                                                                                                   */
             indx         = _mm_unpacklo_epi8(indx_024,indx_135);     /* Merge odd and even indices.                                                                                    */
             indx         = _mm_and_si128(indx,cnst0F);               /* Mask out the high bits 4,5,6,7 of every byte.                                                                  */

             return _mm_or_si128(indx,nonzmsk);                       /* Merge indx with nonzmsk .                                                                                      */
}


int main(){
   int i;
   char w[16],xa[16];
   __m128i x;

   /* Example with bytes 15, 12, 7, 5, 4, 3, 2, 1, 0 set. */
   x = _mm_set_epi8(1,0,0,1,  0,0,0,0,  1,0,1,1,  1,1,1,1);   

   /* Other examples. */
   /* 
   x = _mm_set_epi8(1,1,1,1,  1,1,1,1, 1,1,1,1, 1,1,1,1);   
   x = _mm_set_epi8(0,0,0,0,  0,0,0,0, 0,0,0,0, 0,0,0,0);   
   x = _mm_set_epi8(1,0,0,0,  0,0,0,0, 0,0,0,0, 0,0,0,0);   
   x = _mm_set_epi8(0,0,0,0,  0,0,0,0, 0,0,0,0, 0,0,0,1);   
   */   
   __m128i indices = nonz_index(x);
   _mm_storeu_si128((__m128i *)w,indices);
   _mm_storeu_si128((__m128i *)xa,x);

   printf("counter 15..0 ");for (i=15;i>-1;i--) printf(" %2d ",i);      printf("\n\n");
   printf("example xmm:  ");for (i=15;i>-1;i--) printf(" %2d ",xa[i]);  printf("\n");
   printf("result in dec ");for (i=15;i>-1;i--) printf(" %2hhd ",w[i]); printf("\n");
   printf("result in hex ");for (i=15;i>-1;i--) printf(" %2hhX ",w[i]); printf("\n");

   return 0;
}

It takes about five instructions to put 0xFF (the termination byte) at the unused positions. Note that a variant of nonz_index that returns only the indices plus the position of the first termination byte, without actually inserting the termination byte(s), would be cheaper to compute and might be just as suitable in a particular application. The position of the first termination byte is nnz64>>2.
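Such a cheaper variant could look like the following sketch. The name nonz_index_count and the out-parameter are hypothetical; the body is nonz_index with the cntr/cmpgt/or steps removed, so unused output bytes are simply 0 and the count is returned separately. The target attribute is GCC/Clang-specific and only lets the function compile without -mbmi2 and friends; the CPU must still support BMI2, POPCNT and SSSE3 at run time.

```c
#include <stdint.h>
#include <immintrin.h>

/* Hypothetical variant of nonz_index without termination bytes:
   returns the consecutive indices (unused bytes are 0) and stores the
   number of valid indices, nnz64 >> 2, in *count. */
__attribute__((target("bmi2,popcnt,ssse3")))
static __m128i nonz_index_count(__m128i x, int *count)
{
    const uint64_t indx_const = 0xFEDCBA9876543210ULL;   /* 16 4-bit indices 0..15 */
    const __m128i  pshufbcnst = _mm_set_epi8(-128,-128,-128,-128,-128,-128,-128,-128,
                                             14,12,10,8,6,4,2,0);

    __m128i msk = _mm_cmpeq_epi8(x, _mm_setzero_si128());  /* 0xFF where byte == 0 */
    msk = _mm_srli_epi64(msk, 4);                           /* pack 16x8 bit mask ... */
    msk = _mm_shuffle_epi8(msk, pshufbcnst);                /* ... into 16x4 bit mask */
    uint64_t msk64 = ~(uint64_t)_mm_cvtsi128_si64(msk);     /* nibble i = 0xF <=> byte i != 0 */

    *count = (int)(_mm_popcnt_u64(msk64) >> 2);             /* 4 mask bits per non-zero byte */

    uint64_t indx64 = _pext_u64(indx_const, msk64);         /* compress the wanted nibbles */
    __m128i indx = _mm_cvtsi64_si128((long long)indx64);
    indx = _mm_unpacklo_epi8(indx, _mm_srli_epi64(indx, 4)); /* interleave even/odd nibbles */
    return _mm_and_si128(indx, _mm_set1_epi8(0x0F));         /* keep the low 4 bits per byte */
}
```

For the example input above it returns 0, 1, 2, 3, 4, 5, 7, 12, 15 in bytes 0..8, zeros in bytes 9..15, and sets *count to 9.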

The result is:

$ ./a.out
counter 15..0  15  14  13  12  11  10   9   8   7   6   5   4   3   2   1   0 

example xmm:    1   0   0   1   0   0   0   0   1   0   1   1   1   1   1   1 
result in dec  -1  -1  -1  -1  -1  -1  -1  15  12   7   5   4   3   2   1   0 
result in hex  FF  FF  FF  FF  FF  FF  FF   F   C   7   5   4   3   2   1   0 

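The pext instruction (BMI2) is supported on Intel Haswell processors or newer. AMD processors support it from Excavator onward, but before Zen 3 it is microcoded and much slower.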
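On CPUs without BMI2, a portable bit-by-bit emulation of pext can stand in for _pext_u64. The helper name below is hypothetical, and because it loops once per set mask bit, it gives up the loop-free property that motivates this answer; it is mainly useful as a fallback or for checking results.

```c
#include <stdint.h>

/* Hypothetical software pext: the bits of src selected by mask are
   packed, in order from low to high, into the low bits of the result. */
static uint64_t pext_u64_soft(uint64_t src, uint64_t mask)
{
    uint64_t result = 0;
    for (uint64_t bit = 1; mask != 0; bit <<= 1) {
        uint64_t lowest = mask & (0 - mask);   /* lowest set bit of mask */
        if (src & lowest)
            result |= bit;                     /* copy it to the next output bit */
        mask &= mask - 1;                      /* clear that mask bit */
    }
    return result;
}
```

With src = 0xFEDCBA9876543210 and a mask selecting nibbles 0 and 3 (0xF00F), the result is 0x30, i.e. the indices 0 and 3 packed into consecutive nibbles.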

answered Sep 26 '22 09:09 by wim