 

Fast modulo-12 algorithm for 4 uint16_t's packed in a uint64_t

Consider the following union:

#include <stdint.h>

union Uint16Vect {
    uint16_t _comps[4];
    uint64_t _all;
};

Is there a fast algorithm for determining, for each component, whether it is congruent to 1 modulo 12?

A naive implementation is:

Uint16Vect F(const Uint16Vect a) {
    Uint16Vect r;
    for (int8_t k = 0; k < 4; k++) {
        r._comps[k] = (a._comps[k] % 12 == 1) ? 1 : 0;
    }
    return r;
}
Serge Rogatch asked Feb 16 '19 at 17:02


2 Answers

Compilers optimize division by a constant into a multiplication by the (scaled) reciprocal or by the multiplicative inverse. For example, for a 16-bit x, x / 12 is optimized to (x * 43691) >> 19:

bool h(uint16_t x)
{
    return x % 12 == 1;
}
h(unsigned short):
        movzx   eax, di
        imul    eax, eax, 43691 ; 43691 = ceil(2^19 / 12)
        shr     eax, 19
        lea     eax, [rax+rax*2]
        sal     eax, 2
        sub     edi, eax
        cmp     di, 1
        sete    al
        ret
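Because the input is only 16 bits wide, it is easy to confirm that the multiply-and-shift replacement is exact by comparing it with `/` and `%` for all 65536 inputs. Below is a minimal scalar sketch of what the assembly does; the function names `div12_mul` and `mod12_is_1` are hypothetical, chosen just for illustration.

```c
#include <stdbool.h>
#include <stdint.h>

/* x / 12 for any 16-bit x: multiply by 43691 = ceil(2^19 / 12), shift right 19.
   The product is at most 65535 * 43691 < 2^32, so 32-bit arithmetic is enough. */
static inline uint32_t div12_mul(uint16_t x)
{
    return ((uint32_t)x * 43691u) >> 19;
}

/* x % 12 == 1, with the remainder computed as x - 12 * (x / 12),
   mirroring the lea/sal/sub sequence in the assembly above. */
static inline bool mod12_is_1(uint16_t x)
{
    return (uint32_t)x - 12u * div12_mul(x) == 1;
}
```

For example, `div12_mul(100)` returns 8 and `mod12_is_1(25)` is true. A loop over all 65536 inputs comparing against `x / 12` and `x % 12 == 1` is cheap enough to keep as a unit test.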

Because SSE/AVX have packed multiplication instructions, this is easy to vectorize. Also, x = (x % 12 == 1) ? 1 : 0; can be simplified to x = (x % 12 == 1) and then transformed to x = (x - 1) % 12 == 0, which avoids loading a constant of 1s from memory for the comparison. (x = 0 is still handled correctly: x - 1 is either -1 or, in 16-bit lanes, 65535, and neither is divisible by 12.) You can use GCC's vector extensions so that the compiler generates the SIMD code for you:

typedef uint16_t ymm32x2 __attribute__((vector_size(32)));
ymm32x2 mod12(ymm32x2 x)
{
    return !!((x - 1) % 12);
}

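Below is the output from gcc. Note that, as written, !!((x - 1) % 12) is true in the lanes where x % 12 != 1, which is the inverse of the requested test (the final vpandn performs that inversion); a single ! or (x - 1) % 12 == 0 gives the test itself. A true lane is all-ones (0xFFFF), not 1.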

mod12(unsigned short __vector(16)):
        vpcmpeqd    ymm3, ymm3, ymm3  ; ymm3 = -1
        vpaddw      ymm0, ymm0, ymm3
        vpmulhuw    ymm1, ymm0, YMMWORD PTR .LC0[rip] ; multiply with 43691
        vpsrlw      ymm2, ymm1, 3
        vpsllw      ymm1, ymm2, 1
        vpaddw      ymm1, ymm1, ymm2
        vpsllw      ymm1, ymm1, 2
        vpcmpeqw    ymm0, ymm0, ymm1
        vpandn      ymm0, ymm0, ymm3
        ret

Clang and ICC don't support !! on vector types, so you need to change it to (x - 1) % 12 == 0. Unfortunately, compilers don't seem to use __attribute__((vector_size(8))) to emit MMX instructions, but nowadays you should use SSE or AVX anyway.
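Here is a sketch of that portable `(x - 1) % 12 == 0` form, assuming GCC or Clang vector extensions. The 8-lane type `u16x8` and the function name `mod12_is_1_vec` are my own, not from the code above. A vector comparison yields all-ones (-1) lanes with a signed element type, so the sketch reinterprets the result as unsigned and masks it with 1 to get the 1/0 values the question asks for.

```c
#include <stdint.h>

/* Eight 16-bit lanes in one 128-bit vector (GCC/Clang vector extension). */
typedef uint16_t u16x8 __attribute__((vector_size(16)));

/* Returns 1 in each lane where x % 12 == 1, otherwise 0.
   x - 1 wraps 0 to 65535, and 65535 % 12 == 3, so x == 0 correctly gives 0. */
static inline u16x8 mod12_is_1_vec(u16x8 x)
{
    /* (... == 0) is a signed vector of -1/0; cast to unsigned, then keep bit 0. */
    return (u16x8)((x - 1) % 12 == 0) & 1;
}
```

Since 12 is a constant, the compiler should again turn `% 12` into a multiply-high sequence like the one above; whether that beats the `x % 12 == 1` form is worth measuring.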

The output for x % 12 == 1 is shorter, as you can see on Godbolt, but it needs a constant vector of 1s for the comparison, which may or may not be better. The compiler also may not optimize as well as hand-written code, so you can try to vectorize manually using intrinsics. Benchmark to see which is faster in your case.
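As a rough sketch of the hand-written route, assuming an x86-64 target with SSE2, the function below applies the same multiply-high trick to the four 16-bit lanes of the question's `uint64_t` directly. The name `mod12_is_1_packed` is hypothetical; the intrinsics are the standard SSE2 ones from `<emmintrin.h>`.

```c
#include <stdint.h>
#include <emmintrin.h>  /* SSE2 intrinsics (x86 only) */

/* For each of the four 16-bit lanes packed in `all`, returns 1 in that lane
   if it is congruent to 1 modulo 12, else 0. Assumes x86-64, which provides
   _mm_cvtsi64_si128 and _mm_cvtsi128_si64. */
static inline uint64_t mod12_is_1_packed(uint64_t all)
{
    __m128i x   = _mm_cvtsi64_si128((long long)all);       /* lanes 0..3 = input, 4..7 = 0 */
    __m128i y   = _mm_sub_epi16(x, _mm_set1_epi16(1));     /* y = x - 1; 0 wraps to 65535 */
    /* q = y / 12 = (y * 43691) >> 19: take the high 16 bits of the unsigned
       product, then shift right by 3 more. -21845 has the bit pattern 0xAAAB = 43691. */
    __m128i q   = _mm_srli_epi16(_mm_mulhi_epu16(y, _mm_set1_epi16(-21845)), 3);
    __m128i q12 = _mm_mullo_epi16(q, _mm_set1_epi16(12));  /* 12 * q <= 65532, no overflow */
    __m128i eq  = _mm_cmpeq_epi16(y, q12);                 /* 0xFFFF where y % 12 == 0 */
    __m128i r   = _mm_and_si128(eq, _mm_set1_epi16(1));    /* 0xFFFF -> 1 */
    return (uint64_t)_mm_cvtsi128_si64(r);                 /* keep the low four lanes */
}
```

With the question's union this is just `r._all = mod12_is_1_packed(a._all);`. On little-endian x86, lane 0 of the vector corresponds to `_comps[0]`. Only half of the 128-bit register is used here, so processing eight values per call would make better use of it.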

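A better way is ((x * 43691) & 0x7ffff) < 43691, or x * 357913942 < 357913942 (with 32-bit unsigned wraparound) as mentioned in nwellnhof's answer, both of which should also be easy to vectorize. Note that these test x % 12 == 0, so apply them to x - 1.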
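Both divisibility tests check x % 12 == 0, so the sketch below applies them to y = x - 1 in 32-bit unsigned arithmetic. The idea: multiplying by c = ceil(2^N / 12) moves the fractional part of y / 12 into the low N bits, and those bits are below c only when y is a multiple of 12 (for inputs in this range). The function names are hypothetical; since there are only 65536 inputs, both are easy to confirm exhaustively against x % 12 == 1.

```c
#include <stdbool.h>
#include <stdint.h>

/* Both variants test y % 12 == 0 for y = x - 1, computed in uint32_t.
   For x == 0, y wraps to 0xFFFFFFFF, which fails both tests as required. */

/* 19-bit variant: keep the low 19 bits of y * ceil(2^19 / 12). */
static inline bool mod12_is_1_mask19(uint16_t x)
{
    uint32_t y = (uint32_t)x - 1u;
    return ((y * 43691u) & 0x7ffffu) < 43691u;
}

/* 32-bit variant: the product wraps modulo 2^32 on purpose.
   357913942 = ceil(2^32 / 12). */
static inline bool mod12_is_1_mul32(uint16_t x)
{
    uint32_t y = (uint32_t)x - 1u;
    return y * 357913942u < 357913942u;
}
```

Compared with computing the remainder, this needs one multiply and one compare and no subtraction of 12 * q, which is why it vectorizes well.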


Alternatively, for a small input range like this you can use a lookup table. The basic version needs a 65536-element array:

#include <stdbool.h>

// S1(x) expands to the 8 entries for indices x .. x+7
#define S1(x) ((x) + 0) % 12 == 1, ((x) + 1) % 12 == 1, ((x) + 2) % 12 == 1, ((x) + 3) % 12 == 1, \
              ((x) + 4) % 12 == 1, ((x) + 5) % 12 == 1, ((x) + 6) % 12 == 1, ((x) + 7) % 12 == 1
// S2(x) covers 8x .. 8x+63, S3(x) covers 64x .. 64x+511, S4(x) covers 512x .. 512x+4095
#define S2(x) S1((x + 0)*8), S1((x + 1)*8), S1((x + 2)*8), S1((x + 3)*8), \
              S1((x + 4)*8), S1((x + 5)*8), S1((x + 6)*8), S1((x + 7)*8)
#define S3(x) S2((x + 0)*8), S2((x + 1)*8), S2((x + 2)*8), S2((x + 3)*8), \
              S2((x + 4)*8), S2((x + 5)*8), S2((x + 6)*8), S2((x + 7)*8)
#define S4(x) S3((x + 0)*8), S3((x + 1)*8), S3((x + 2)*8), S3((x + 3)*8), \
              S3((x + 4)*8), S3((x + 5)*8), S3((x + 6)*8), S3((x + 7)*8)

// S4(8*k) covers indices 4096*k .. 4096*k + 4095, so 16 of them fill all 65536 entries
bool mod12e1[65536] = {
    S4(0U),  S4(8U),  S4(16U), S4(24U), S4(32U), S4(40U),  S4(48U),  S4(56U),
    S4(64U), S4(72U), S4(80U), S4(88U), S4(96U), S4(104U), S4(112U), S4(120U)
};

To use it, just replace x % 12 == 1 with mod12e1[x]. This can of course be vectorized.

Since each result is only 0 or 1, you can also use a 65536-bit array, which reduces the size to only 8 KB.
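A minimal sketch of the 8 KB bit-array variant, assuming the table is filled once at startup instead of being written out as an initializer. The names `mod12e1_bits`, `init_mod12e1_bits` and `mod12e1_lookup` are hypothetical.

```c
#include <stdbool.h>
#include <stdint.h>

/* One bit per 16-bit input: 65536 bits = 1024 * 64 bits = 8192 bytes. */
static uint64_t mod12e1_bits[65536 / 64];

/* Fill the table once before the first lookup. Calling it again is harmless. */
static void init_mod12e1_bits(void)
{
    for (uint32_t x = 0; x < 65536; x++)
        if (x % 12 == 1)
            mod12e1_bits[x >> 6] |= (uint64_t)1 << (x & 63);
}

/* Bit x lives in word x / 64, at bit position x % 64. */
static inline bool mod12e1_lookup(uint16_t x)
{
    return (mod12e1_bits[x >> 6] >> (x & 63)) & 1;
}
```

The smaller table is more likely to stay in L1/L2 cache than the 64 KB `bool` array, at the cost of a shift and mask per lookup.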


You can also check divisibility by 12 by checking divisibility by both 4 and 3 (12 = 4 × 3, and 4 and 3 are coprime). Divisibility by 4 is trivial: just test the low two bits. Divisibility by 3 can be checked in several ways:

  • One is to compute the difference between the sum of the bits in odd positions and the sum of the bits in even positions, as in גלעד ברקן's answer, and check whether it's divisible by 3 (this works because 2^k ≡ (-1)^k mod 3).

  • Another is to check whether the sum of the digits in base 2^(2k) (base 4, 16, 64, ...) is divisible by 3.

    That works because in base b, for any divisor n of b - 1, a number is divisible by n exactly when the sum of its digits is (every power of b is ≡ 1 mod n). Since 2^(2k) - 1 = 4^k - 1 is always divisible by 3, this applies to n = 3. Here's an implementation of it:

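      #include <stdbool.h>
      #include <stdint.h>

      // Replaces each d[i] with 1 if d[i] % 12 == 1, else 0
      void modulo12equals1(uint16_t d[], uint32_t size) {
          for (uint32_t i = 0; i < size; i++)
          {
              uint16_t x = d[i] - 1;       // 0 wraps to 65535, which is not divisible by 12
              bool divisibleBy4 = x % 4 == 0;
              x = (x >> 8) + (x & 0x00ff); // base-256 digit sum, max 1FE; 256 ≡ 1 (mod 3)
              x = (x >> 4) + (x & 0x000f); // base-16 digit sum, max 2D; 16 ≡ 1 (mod 3)
              // The octal constant has a set bit at every multiple of 3, so bit x is set iff x % 3 == 0
              bool divisibleBy3 = !!((01111111111111111111111ULL >> x) & 1);
              d[i] = divisibleBy3 && divisibleBy4;
          }
      }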
    

Credit for the divisibility-by-3 trick goes to Roland Illig.

The auto-vectorized assembly output is too long to include here; you can check it on Godbolt.

See also

  • How to know if a binary number divides by 3?
  • Determine whether or not a binary number is divisible by 3
  • Bit representation and divisibility by 3
  • building circuit for divisibility by 3
  • Check if a number is divisible by 3
  • Logic to check the number is divisible by 3 or not?
phuclv answered Sep 28 '22 at 02:09


If it would help to limit operations to bit operations and popcount, we can observe that a valid candidate must pass two tests, since n - 1 must be divisible by both 4 and 3. First, the last two bits must be 01. Then, for divisibility by 3, subtract the odd-positioned popcount from the even-positioned popcount; the difference must be divisible by 3.

const evenMask = parseInt('1010101010101010', 2);
// Leave out first bit, we know it will be zero
// after subtracting 1
const oddMask = parseInt('101010101010100', 2);

console.log('n , Test 1: (n & 3)^3, Test 2: popcount diff:\n\n');

for (let n=0; n<500; n++){
  if (n % 12 == 1)
    console.log(
      n,
      (n & 3)^3,
      popcount(n & evenMask) - popcount(n & oddMask))
}

// https://stackoverflow.com/questions/43122082/efficiently-count-the-number-of-bits-in-an-integer-in-javascript
function popcount(n) {
  var tmp = n;
  var count = 0;
  while (tmp > 0) {
    tmp = tmp & (tmp - 1);
    count++;
  }
  return count;
}
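For comparison with the C answers, here is a rough translation of the two tests into a single C predicate, assuming GCC or Clang for the `__builtin_popcount` builtin. The function name is hypothetical, and the masks use 0-indexed bit positions (0x5555 holds bits 0, 2, 4, ...), so they are named the other way round from the JavaScript above.

```c
#include <stdbool.h>
#include <stdint.h>

/* n % 12 == 1 using only bit operations and popcount.
   Test 1: n - 1 divisible by 4  <=>  the low two bits of n are 01.
   Test 2: n - 1 divisible by 3. Since 2^k mod 3 is 1 for even k and -1 for odd k,
   m == popcount(even-position bits) - popcount(odd-position bits)  (mod 3).
   When the low bits are 01, n - 1 is just n with bit 0 cleared. */
static inline bool mod12_is_1_popcount(uint16_t n)
{
    if ((n & 3) != 1)
        return false;
    uint16_t m = n & ~1u;  /* n - 1 */
    int diff = __builtin_popcount(m & 0x5555u) - __builtin_popcount(m & 0xAAAAu);
    return diff % 3 == 0;  /* also correct for negative diff in C */
}
```

For n = 13 (binary 1101), m = 12 (1100) has one set bit in an even position (bit 2) and one in an odd position (bit 3), so the difference is 0 and the test passes.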
גלעד ברקן answered Sep 28 '22 at 00:09