Multishift operation

How can one implement, without a loop, an operation on bitmasks which, for two bitmasks a and b of width n, gives a bitmask c of width 2 * n with the following property:

  • the i-th bit in c is set if and only if there exist j and k such that the j-th bit of a is set, the k-th bit of b is set, and j + k == i

C++ implementation:

#include <bitset>
#include <algorithm>
#include <iostream>

#include <cstdint>
#include <cassert>

#include <x86intrin.h>

std::uint64_t multishift(std::uint32_t a, std::uint32_t b)
{
    std::uint64_t c = 0;
    if (_popcnt32(b) < _popcnt32(a)) {
        std::swap(a, b); // iterate over the operand with fewer set bits
    }
    assert(a != 0);
    do {
        c |= std::uint64_t{b} << (_bit_scan_forward(a) + 1); // the + 1 means bit positions are effectively counted from 1
    } while ((a &= (a - 1)) != 0); // clear least set bit
    return c;
}

int main()
{
    std::cout << std::bitset< 64 >(multishift(0b1001, 0b0101)) << std::endl; // ...0001011010
}

Can it be reimplemented without a loop, using some bit tricks or some x86 instructions?

asked by Tomilov Anatoliy

2 Answers

This is like a multiplication that uses OR instead of ADD. As far as I know there is no truly amazing trick. But here's a trick that actually avoids intrinsics rather than using them:

while (a) {
    c |= b * (a & -a); // multiply by the lowest set bit of a == shift b left by its index
    a &= a - 1;        // clear the lowest set bit of a
}

This is very similar to your algorithm, but it uses a multiplication to shift b left by the trailing zero count of a; a & -a is a trick that isolates just the lowest set bit as a mask. As a bonus, that expression is safe to evaluate when a == 0, so you can unroll (and/or turn the while into a do/while without a precondition) without nasty edge cases showing up (which is not the case with TZCNT and a shift).
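
For contrast, here is a minimal sketch (my own, not part of the original answer) of the TZCNT-and-shift variant alluded to above, assuming BMI1 is available (compile with e.g. -mbmi). Because _tzcnt_u64(0) returns 64 and shifting a 64-bit value by 64 is undefined in C and C++, this version needs the while (a) guard and cannot be blindly unrolled the way the multiply version can:

#include <stdint.h>
#include <x86intrin.h>

uint64_t multishift_tzcnt(uint32_t a_32, uint32_t b_32)
{
    uint64_t a = a_32, b = b_32, c = 0;
    while (a) {
        c |= b << _tzcnt_u64(a);   /* shift b by the index of a's lowest set bit */
        a &= a - 1;                /* clear the lowest set bit */
    }
    return c;
}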


pshufb could be used in parallel table-lookup mode: a nibble of a selects a sub-table, which is then used to multiply all nibbles of b by that nibble of a in one instruction. For the multiplication proper that is at most 8 pshufbs (or simply always 8, since there is little point in trying to early-exit with this approach). It does require some weird setup at the start and some unfortunate horizontal work to finish it off, so it may not be so great.

answered by harold

Since this question is tagged BMI (Intel Haswell and newer, AMD Excavator and newer), I assume that AVX2 is also of interest here. Note that all processors with BMI2 support also have AVX2 support. On modern hardware, such as Intel Skylake, a brute-force AVX2 approach may perform quite well compared with a scalar (do/while) loop, unless a or b has only very few bits set. The code below computes all 32 multiplications (see Harold's multiplication idea) without searching for the non-zero bits in the input. After the multiplications, everything is OR-ed together.
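
For illustration, here is a hypothetical scalar rendering of the same brute-force idea (my own sketch, not part of the benchmarked code): every bit position of b is handled unconditionally, so there is no data-dependent branch; the AVX2 routine below performs eight of these products per vpmuludq.

#include <stdint.h>

uint64_t mult_shft_bruteforce(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    for (int k = 0; k < 32; k++) {
        /* a times the isolated bit k of b equals a << k when that bit is set */
        c |= (uint64_t)a_32 * (b_32 & (UINT64_C(1) << k));
    }
    return c;
}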

With random input values of a, the scalar code (see above: Orient's question or Harold's answer) may suffer from branch misprediction: the number of instructions per cycle (IPC) is lower than the IPC of the branchless AVX2 solution. I did a few tests with the different versions. It turns out that Harold's code benefits significantly from loop unrolling. In these tests throughput is measured, not latency. Two cases are considered: 1. a random a with on average 3.5 bits set, and 2. a random a with on average 16.0 bits set (50%). The CPU is an Intel Skylake i5-6500.

Time in seconds
                             3.5 bits set   16 bits set
mult_shft_Orient()               0.81           1.51
mult_shft_Harold()               0.84           1.51
mult_shft_Harold_unroll2()       0.64           1.58
mult_shft_Harold_unroll4()       0.48           1.34
mult_shft_AVX2()                 0.44           0.40

Note that these timings include the generation of the random input numbers a and b. The AVX2 code benefits from the fast vpmuludq instruction: it has a throughput of two per cycle on Skylake.


The code:

/*      gcc -Wall -m64 -O3 -march=broadwell mult_shft.c    */
#include <stdint.h>
#include <stdio.h>
#include <x86intrin.h>

uint64_t mult_shft_AVX2(uint32_t a_32, uint32_t b_32) {

    __m256i  a          = _mm256_broadcastq_epi64(_mm_cvtsi32_si128(a_32));
    __m256i  b          = _mm256_broadcastq_epi64(_mm_cvtsi32_si128(b_32));
                                                           //  0xFEDCBA9876543210  0xFEDCBA9876543210  0xFEDCBA9876543210  0xFEDCBA9876543210     
    /* isolate the individual bits of b: lane m of b_k holds bit 4*k + m of b                                                                   */
    __m256i  b_0        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000008, 0x0000000000000004, 0x0000000000000002, 0x0000000000000001));
    __m256i  b_1        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000080, 0x0000000000000040, 0x0000000000000020, 0x0000000000000010));
    __m256i  b_2        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000000800, 0x0000000000000400, 0x0000000000000200, 0x0000000000000100));
    __m256i  b_3        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000008000, 0x0000000000004000, 0x0000000000002000, 0x0000000000001000));
    __m256i  b_4        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000080000, 0x0000000000040000, 0x0000000000020000, 0x0000000000010000));
    __m256i  b_5        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000000800000, 0x0000000000400000, 0x0000000000200000, 0x0000000000100000));
    __m256i  b_6        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000008000000, 0x0000000004000000, 0x0000000002000000, 0x0000000001000000));
    __m256i  b_7        = _mm256_and_si256(b,_mm256_set_epi64x(0x0000000080000000, 0x0000000040000000, 0x0000000020000000, 0x0000000010000000));

    /* vpmuludq multiplies the unsigned low 32 bits of each 64-bit lane:
       a * (isolated bit of b) == a shifted left by that bit's position   */
    __m256i  m_0        = _mm256_mul_epu32(a, b_0);
    __m256i  m_1        = _mm256_mul_epu32(a, b_1);
    __m256i  m_2        = _mm256_mul_epu32(a, b_2);
    __m256i  m_3        = _mm256_mul_epu32(a, b_3);
    __m256i  m_4        = _mm256_mul_epu32(a, b_4);
    __m256i  m_5        = _mm256_mul_epu32(a, b_5);
    __m256i  m_6        = _mm256_mul_epu32(a, b_6);
    __m256i  m_7        = _mm256_mul_epu32(a, b_7);

    /* OR the eight partial products together (tree reduction) */
             m_0        = _mm256_or_si256(m_0, m_1);
             m_2        = _mm256_or_si256(m_2, m_3);     
             m_4        = _mm256_or_si256(m_4, m_5);     
             m_6        = _mm256_or_si256(m_6, m_7);
             m_0        = _mm256_or_si256(m_0, m_2);     
             m_4        = _mm256_or_si256(m_4, m_6);     
             m_0        = _mm256_or_si256(m_0, m_4);     

    /* horizontal OR of the four 64-bit lanes down to a single scalar */
    __m128i  m_0_lo     = _mm256_castsi256_si128(m_0);
    __m128i  m_0_hi     = _mm256_extracti128_si256(m_0, 1);
    __m128i  e          = _mm_or_si128(m_0_lo, m_0_hi);
    __m128i  e_hi       = _mm_unpackhi_epi64(e, e);
             e          = _mm_or_si128(e, e_hi);
    uint64_t c          = _mm_cvtsi128_si64x(e);
                          return c; 
}


/* zero-based variant of the do/while loop from the question */
uint64_t mult_shft_Orient(uint32_t a, uint32_t b) {
    uint64_t c = 0;
    do {
        c |= ((uint64_t)b) << (_bit_scan_forward(a) );
    } while ((a = a & (a - 1)) != 0);
    return c; 
}


uint64_t mult_shft_Harold(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}


uint64_t mult_shft_Harold_unroll2(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}


uint64_t mult_shft_Harold_unroll4(uint32_t a_32, uint32_t b_32) {
    uint64_t c = 0;
    uint64_t a = a_32;
    uint64_t b = b_32;
    while (a) {
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
       c |= b * (a & -a);
       a &= a - 1;
    }
    return c; 
}



int main(){

uint32_t a,b;

/*
uint64_t c0, c1, c2, c3, c4;
a = 0x10036011;
b = 0x31000107;

//a = 0x80000001;
//b = 0x80000001;

//a = 0xFFFFFFFF;
//b = 0xFFFFFFFF;

//a = 0x00000001;
//b = 0x00000001;

//a = 0b1001;
//b = 0b0101;

c0 = mult_shft_Orient(a, b);        
c1 = mult_shft_Harold(a, b);        
c2 = mult_shft_Harold_unroll2(a, b);
c3 = mult_shft_Harold_unroll4(a, b);
c4 = mult_shft_AVX2(a, b);          
printf("%016lX \n%016lX     \n%016lX     \n%016lX     \n%016lX \n\n", c0, c1, c2, c3, c4);
*/

uint32_t rnd = 0xA0036011;
uint32_t rnd_old;
uint64_t c;
uint64_t sum = 0;
double popcntsum =0.0;
int i;

for (i=0;i<100000000;i++){
   rnd_old = rnd;
   rnd = _mm_crc32_u32(rnd, i);      /* simple random generator                                   */
   b = rnd;                          /* the actual value of b has no influence on the performance */
   a = rnd;                          /* `a` has about 50% nonzero bits                            */

#if 1 == 1                           /* reduce number of set bits from about 16 to 3.5                 */
   a = rnd & rnd_old;                                   /* now `a` has about 25 % nonzero bits    */
          /*0bFEDCBA9876543210FEDCBA9876543210 */     
   a = (a & 0b00110000101001000011100010010000) | 1;    /* about 3.5 nonzero bits on average      */                  
#endif   
/*   printf("a = %08X \n", a);                */

//   popcntsum = popcntsum + _mm_popcnt_u32(a); 
                                               /*   3.5 nonz       50%   (time in sec.)  */
//   c = mult_shft_Orient(a, b );              /*      0.81          1.51                  */
//   c = mult_shft_Harold(a, b );              /*      0.84          1.51                */
//   c = mult_shft_Harold_unroll2(a, b );      /*      0.64          1.58                */
//   c = mult_shft_Harold_unroll4(a, b );      /*      0.48          1.34                */
   c = mult_shft_AVX2(a, b );                /*      0.44          0.40                */
   sum = sum + c;
}
printf("sum = %016lX \n\n", sum);

printf("average density = %f bits per uint32_t\n\n", popcntsum/100000000);

return 0;
}

Function mult_shft_AVX2() uses zero-based bit indices, just like Harold's answer. It looks like in your question you start counting bits at 1 instead of 0. You may want to multiply the answer by 2 to get the same results.
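
A tiny self-contained check of that indexing note (my own sketch, using Harold's multiply trick): with the question's example a = 0b1001 (bits 0 and 3) and b = 0b0101 (bits 0 and 2), the zero-based result is 0b101101; shifting it left by one reproduces the question's expected output 0b1011010.

#include <stdint.h>
#include <stdio.h>

int main(void)
{
    uint64_t a = 0x9, b = 0x5, c = 0;          /* a = 0b1001, b = 0b0101 */
    while (a) {
        c |= b * (a & -a);                     /* zero-based multishift  */
        a &= a - 1;                            /* clear lowest set bit   */
    }
    printf("zero-based: 0x%llX, question's convention: 0x%llX\n",
           (unsigned long long)c, (unsigned long long)(c << 1));
    /* prints: zero-based: 0x2D, question's convention: 0x5A (0b1011010) */
    return 0;
}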

answered by wim