
Count leading zero bits for each element in AVX2 vector, emulate _mm256_lzcnt_epi32

With AVX-512 (AVX512CD, plus AVX512VL for the 256-bit form), there is the intrinsic _mm256_lzcnt_epi32. It returns a vector in which each of the 8 32-bit elements holds the number of leading zero bits in the corresponding element of the input vector.

Is there an efficient way to implement this using AVX and AVX2 instructions only?

Currently I'm using a loop which extracts each element and applies the _lzcnt_u32 function.
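For comparison, the scalar loop described above might look like the sketch below. The function names are hypothetical. It assumes GCC or Clang on x86-64, where the target attribute lets the file build without -mavx2/-mlzcnt, and a CPU with AVX2 and LZCNT. On a CPU without LZCNT, the lzcnt encoding executes as bsr and gives different results.

```c
#include <immintrin.h>
#include <stdint.h>

// Baseline: spill the vector, run scalar lzcnt on each element, reload.
__attribute__((target("avx2,lzcnt")))
static __m256i lzcnt_epi32_scalar_loop(__m256i v) {
    uint32_t tmp[8];
    _mm256_storeu_si256((__m256i *)tmp, v);
    for (int i = 0; i < 8; i++)
        tmp[i] = _lzcnt_u32(tmp[i]);   // _lzcnt_u32(0) == 32
    return _mm256_loadu_si256((const __m256i *)tmp);
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2,lzcnt")))
void lzcnt_u32x8_scalar(const uint32_t in[8], uint32_t out[8]) {
    __m256i v = _mm256_loadu_si256((const __m256i *)in);
    _mm256_storeu_si256((__m256i *)out, lzcnt_epi32_scalar_loop(v));
}
```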


Related: to bit-scan one large bitmap, see "Count leading zeros in __m256i word". It uses pmovmskb -> bitscan to find which byte to do a scalar bitscan on.

This question is about doing 8 separate lzcnts on 8 separate 32-bit elements when you're actually going to use all 8 results, not just select one.
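For contrast with this question, the whole-bitmap approach from the related link can be sketched in three steps:

1. Compare the bytes against zero.
2. Use pmovmskb to find the highest non-zero byte.
3. Do a scalar bitscan on just that byte.

This is a sketch of the idea, not the linked answer's code. The names are hypothetical. It treats byte 31 as the most significant byte, which is little-endian storage order. It assumes GCC or Clang for __builtin_clz and the target attribute.

```c
#include <immintrin.h>
#include <stdint.h>

// Leading-zero count of one whole 256-bit value (byte 31 most significant).
__attribute__((target("avx2")))
static unsigned lzcnt_m256i_bitmap(__m256i v) {
    // Bit i of zero_bytes is set when byte i of v is zero.
    uint32_t zero_bytes = (uint32_t)_mm256_movemask_epi8(
        _mm256_cmpeq_epi8(v, _mm256_setzero_si256()));
    uint32_t nonzero_bytes = ~zero_bytes;
    if (nonzero_bytes == 0)
        return 256;
    // Index of the most significant non-zero byte.
    unsigned hi = 31u - (unsigned)__builtin_clz(nonzero_bytes);
    uint8_t bytes[32];
    _mm256_storeu_si256((__m256i *)bytes, v);
    // Whole zero bytes above it, plus the leading zeros inside that byte.
    return (31u - hi) * 8u + ((unsigned)__builtin_clz(bytes[hi]) - 24u);
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2")))
unsigned lzcnt_bitmap256(const uint8_t bytes[32]) {
    return lzcnt_m256i_bitmap(_mm256_loadu_si256((const __m256i *)bytes));
}
```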

asked Nov 12 '19 at 16:11 by tmlen


2 Answers

float represents numbers in an exponential format, so int->FP conversion gives us the position of the highest set bit encoded in the exponent field.

We want int->float conversion with the magnitude rounded down (truncating the value toward 0), not the default round-to-nearest. Round-to-nearest could round up and make 0x3FFFFFFF look like 0x40000000. If you're doing a lot of these conversions without doing any FP math, you could set the rounding mode in the MXCSR (footnote 1) to truncation, then set it back when you're done.

Otherwise, you can use v & ~(v>>8). This keeps the 8 most significant bits and zeroes some or all of the lower bits, including any set bit exactly 8 positions below the MSB.

That's enough to ensure that no rounding mode ever rounds up to the next power of two. float has a 24-bit significand, so rounding up only adds 1 at bit MSB-23. With bit MSB-8 clear, that carry can't propagate past the MSB.

The MSB itself is never cleared. Bit k of ~(v>>8) is the inverse of bit k+8 of v. For the MSB and the 7 bits below it, bit k+8 lies above the MSB, or is one of the 8 zeros shifted in at the top, so the mask has ones there. Depending on how set bits below the MSB line up, the mask might or might not clear more bits below the top 8.

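After conversion, we use an integer shift on the bit pattern to bring the exponent (and sign bit) to the bottom. Then we undo the bias with a saturating subtract. Because 158 = 127 + 31, 158 minus the biased exponent is 31 minus the index of the MSB. For a negative input, the sign bit sits above the exponent, so the subtraction saturates to 0. We use min to set the result to 32 if no bits were set in the original 32-bit input: 0.0f has an exponent field of 0, which would otherwise give 158.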
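A scalar model of one lane can make these steps concrete. This sketch is plain C with hypothetical names and the default round-to-nearest mode. It computes lzcnt with and without the v & ~(v>>8) mask. Without the mask, 0x3FFFFFFF is misreported as having 1 leading zero instead of 2.

```c
#include <stdint.h>
#include <string.h>

// Raw IEEE-754 bit pattern of a float (memcpy avoids aliasing problems).
static uint32_t float_bits(float f) {
    uint32_t u;
    memcpy(&u, &f, sizeof u);
    return u;
}

// One 32-bit lane of the float-exponent trick, modeled in scalar C.
// Pass mask_first = 0 to see what goes wrong without v & ~(v >> 8).
unsigned lzcnt_via_float(uint32_t v, int mask_first) {
    if (mask_first)
        v &= ~(v >> 8);          // keep the MSB and the 7 bits below it, clear bit MSB-8
    // uint32 -> int32 is implementation-defined for values >= 2^31; GCC and
    // Clang wrap modulo 2^32, matching what vcvtdq2ps sees in a vector lane.
    float f = (float)(int32_t)v;                 // rounds to nearest by default
    int32_t e = (int32_t)(float_bits(f) >> 23);  // sign bit and biased exponent
    int32_t r = 158 - e;         // 158 = 127 (bias) + 31
    if (r < 0) r = 0;            // negative input: the sign bit makes e >= 256
    if (r > 32) r = 32;          // input 0: 0.0f has exponent 0, so r would be 158
    return (unsigned)r;
}
```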

#include <immintrin.h>

__m256i avx2_lzcnt_epi32 (__m256i v) {
    // prevent value from being rounded up to the next power of two
    v = _mm256_andnot_si256(_mm256_srli_epi32(v, 8), v); // keep 8 MSB

    v = _mm256_castps_si256(_mm256_cvtepi32_ps(v)); // convert an integer to float
    v = _mm256_srli_epi32(v, 23); // shift down the exponent (and sign bit)
    // 16-bit ops are safe from here on: the upper half of each 32-bit lane is 0.
    // 158 = 127 (bias) + 31; a set sign bit makes v >= 256, which saturates to 0.
    v = _mm256_subs_epu16(_mm256_set1_epi32(158), v); // undo bias
    v = _mm256_min_epi16(v, _mm256_set1_epi32(32)); // clamp at 32 (input 0 gives 158)

    return v;
}

Footnote 1: fp->int conversion is available with truncation (cvtt), but int->fp conversion always uses the current rounding mode from MXCSR.

AVX512F introduces rounding-mode overrides for 512-bit vectors, which would solve the problem: __m512 _mm512_cvt_roundepi32_ps(__m512i a, int r). But all CPUs with AVX512F also support AVX512CD, so you could just use _mm512_lzcnt_epi32. With AVX512VL, you can use _mm256_lzcnt_epi32.

answered Nov 05 '22 at 06:11 by aqrit


@aqrit's answer looks like a cleverer use of FP bithacks. My answer below is based on the first place I looked for a bithack. That source was old and aimed at scalar code, so it didn't try to avoid double. double is wider than int32, which is a problem for SIMD: a 256-bit vector holds only 4 doubles vs. 8 int32 elements.

That answer uses HW signed int->float conversion and saturating integer subtracts to handle the MSB being set (negative float). This avoids stuffing bits into a mantissa for a manual uint->double conversion. If you can set MXCSR to round down across a lot of these _mm256_lzcnt_epi32 emulations, that's even more efficient, because the masking step can be dropped.


https://graphics.stanford.edu/~seander/bithacks.html#IntegerLogIEEE64Float suggests stuffing integers into the mantissa of a large double, then subtracting so the FPU hardware produces a normalized double. (I think this bit of magic is doing uint32_t -> double. @Mysticial explains the technique in "How to efficiently perform double/int64 conversions with SSE/AVX?", where it works for uint64_t up to 2^52-1.)

Then grab the exponent bits of the double and undo the bias.

Integer log2 here means floor(log2(v)), the index of the highest set bit. For nonzero 32-bit v it relates to lzcnt by lzcnt(v) = 31 - log2(v). v = 0 needs special handling, since lzcnt(0) = 32.

The Stanford Graphics bithack page lists other branchless bithacks you could use that would probably still be better than 8x scalar lzcnt.
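Here is one example of a different branchless technique: bit smearing plus popcount, which is not taken from either answer. OR each element with right-shifted copies of itself so every bit below the MSB becomes set, then compute lzcnt = 32 - popcount. The sketch below does the popcount with the common vpshufb nibble lookup table. The names are hypothetical, and it assumes GCC or Clang on x86-64 with AVX2. It needs noticeably more instructions than the FP-exponent trick.

```c
#include <immintrin.h>
#include <stdint.h>

__attribute__((target("avx2")))
static __m256i lzcnt_epi32_smear_popcnt(__m256i v) {
    // Smear the highest set bit into every lower bit position.
    v = _mm256_or_si256(v, _mm256_srli_epi32(v, 1));
    v = _mm256_or_si256(v, _mm256_srli_epi32(v, 2));
    v = _mm256_or_si256(v, _mm256_srli_epi32(v, 4));
    v = _mm256_or_si256(v, _mm256_srli_epi32(v, 8));
    v = _mm256_or_si256(v, _mm256_srli_epi32(v, 16));

    // Per-byte popcount: look up each nibble in a 16-entry table.
    const __m256i lut = _mm256_setr_epi8(0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4,
                                         0,1,1,2,1,2,2,3,1,2,2,3,2,3,3,4);
    const __m256i low4 = _mm256_set1_epi8(0x0F);
    __m256i lo = _mm256_and_si256(v, low4);
    __m256i hi = _mm256_and_si256(_mm256_srli_epi16(v, 4), low4);
    __m256i bytes = _mm256_add_epi8(_mm256_shuffle_epi8(lut, lo),
                                    _mm256_shuffle_epi8(lut, hi));
    // Sum the 4 byte counts of each 32-bit element.
    __m256i pairs = _mm256_maddubs_epi16(bytes, _mm256_set1_epi8(1));
    __m256i pop = _mm256_madd_epi16(pairs, _mm256_set1_epi16(1));
    return _mm256_sub_epi32(_mm256_set1_epi32(32), pop);
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2")))
void lzcnt_u32x8_smear(const uint32_t in[8], uint32_t out[8]) {
    __m256i v = _mm256_loadu_si256((const __m256i *)in);
    _mm256_storeu_si256((__m256i *)out, lzcnt_epi32_smear_popcnt(v));
}
```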

If you knew your numbers were always small-ish (like less than 2^23) you could maybe do this with float and avoid splitting and blending.
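Under that small-input assumption (0 <= v < 2^23), one float version of the magic-number trick needs no splitting or blending: OR v into the mantissa of 2^23, subtract 2^23, and read the exponent. The names are hypothetical, and the sketch assumes GCC or Clang with AVX2. For inputs in this range, a plain _mm256_cvtepi32_ps is also exact, because every integer below 2^24 is representable as a float.

```c
#include <immintrin.h>
#include <stdint.h>

// Only valid when every element satisfies 0 <= v < 2^23.
__attribute__((target("avx2")))
static __m256i lzcnt_epi32_small(__m256i v) {
    const __m256i magic_bits = _mm256_set1_epi32(0x4B000000);     // bit pattern of 2^23f
    __m256 f = _mm256_castsi256_ps(_mm256_or_si256(v, magic_bits)); // 2^23 + v, exactly
    f = _mm256_sub_ps(f, _mm256_set1_ps(8388608.0f));             // exactly v (or +0.0)
    __m256i e = _mm256_srli_epi32(_mm256_castps_si256(f), 23);     // biased exponent, sign is 0
    __m256i r = _mm256_sub_epi32(_mm256_set1_epi32(158), e);       // 127 + 31 - exponent
    return _mm256_min_epi32(r, _mm256_set1_epi32(32));             // v == 0 gives 158 -> 32
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2")))
void lzcnt_u32x8_small(const uint32_t in[8], uint32_t out[8]) {
    __m256i v = _mm256_loadu_si256((const __m256i *)in);
    _mm256_storeu_si256((__m256i *)out, lzcnt_epi32_small(v));
}
```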

  int v; // 32-bit integer to find the log base 2 of
  int r; // result of log_2(v) goes here
  union { unsigned int u[2]; double d; } t; // temp

  t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] = 0x43300000; // high word: exponent of 2^52
  t.u[__FLOAT_WORD_ORDER!=LITTLE_ENDIAN] = v;          // low word: v in the mantissa
  t.d -= 4503599627370496.0;                           // subtract 2^52, leaving (double)v
  r = (t.u[__FLOAT_WORD_ORDER==LITTLE_ENDIAN] >> 20) - 0x3FF; // exponent minus bias

The code above loads a 64-bit (IEEE-754 floating-point) double with a 32-bit integer (with no padding bits) by storing the integer in the mantissa while the exponent is set to 2^52. From this newly minted double, 2^52 (expressed as a double) is subtracted, which sets the resulting exponent to the log base 2 of the input value, v. All that is left is shifting the exponent bits into position (20 bits right) and subtracting the bias, 0x3FF (which is 1023 decimal).
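A portable scalar version of the same trick might look like this sketch, with hypothetical names. It uses memcpy instead of the union, so it doesn't depend on __FLOAT_WORD_ORDER. It also derives lzcnt from the log2 result, handling 0 separately.

```c
#include <stdint.h>
#include <string.h>

// floor(log2(v)) via the 2^52 magic-number trick; returns -1023 for v == 0.
int log2_u32_via_double(uint32_t v) {
    uint64_t bits = 0x4330000000000000ULL | v;  // exponent field of 2^52, v in the mantissa
    double d;
    memcpy(&d, &bits, sizeof d);                // d == 2^52 + v exactly
    d -= 4503599627370496.0;                    // subtract 2^52: d == v exactly
    memcpy(&bits, &d, sizeof bits);
    return (int)(bits >> 52) - 0x3FF;           // biased exponent minus the bias (1023)
}

// For nonzero v, lzcnt(v) == 31 - floor(log2(v)).
unsigned lzcnt_u32_via_double(uint32_t v) {
    return v ? 31u - (unsigned)log2_u32_via_double(v) : 32u;
}
```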

To do this with AVX2, blend and shift+blend the odd/even halves with set1_epi32(0x43300000), then use _mm256_castsi256_pd to get a __m256d. After subtracting, use _mm256_castpd_si256 and shift / blend the low/high halves into place, then mask to get the exponents.
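One way to write those steps with intrinsics is sketched below, with hypothetical names, assuming GCC or Clang and AVX2. It treats each element as uint32_t. After the exact subtraction, every value is non-negative, so the sign bit is 0. That means shifts alone isolate the exponents, and the separate mask step isn't needed. Input 0 becomes +0.0 with exponent 0 and is clamped to 32.

```c
#include <immintrin.h>
#include <stdint.h>

__attribute__((target("avx2")))
static __m256i lzcnt_epi32_via_double(__m256i v) {
    const __m256i magic_hi = _mm256_set1_epi32(0x43300000);   // high dword of 2^52
    const __m256d two52 = _mm256_set1_pd(4503599627370496.0); // 2^52

    // Even elements already sit in the low dword of each 64-bit lane;
    // odd elements are shifted down into it. Then put 0x43300000 on top.
    __m256i even = _mm256_blend_epi32(v, magic_hi, 0xAA);
    __m256i odd  = _mm256_blend_epi32(_mm256_srli_epi64(v, 32), magic_hi, 0xAA);

    // (2^52 + x) - 2^52 is exact and leaves x as a normalized double.
    __m256i de = _mm256_castpd_si256(_mm256_sub_pd(_mm256_castsi256_pd(even), two52));
    __m256i dd = _mm256_castpd_si256(_mm256_sub_pd(_mm256_castsi256_pd(odd), two52));

    // Biased exponents: low dword for even elements, high dword for odd ones.
    __m256i e_even = _mm256_srli_epi64(de, 52);
    __m256i e_odd  = _mm256_srli_epi32(dd, 20);  // low dwords are junk, blended away
    __m256i e = _mm256_blend_epi32(e_even, e_odd, 0xAA);

    // lzcnt = 1023 + 31 - exponent; input 0 gives 1054, clamped to 32.
    __m256i r = _mm256_sub_epi32(_mm256_set1_epi32(1054), e);
    return _mm256_min_epi32(r, _mm256_set1_epi32(32));
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2")))
void lzcnt_u32x8_via_double(const uint32_t in[8], uint32_t out[8]) {
    __m256i v = _mm256_loadu_si256((const __m256i *)in);
    _mm256_storeu_si256((__m256i *)out, lzcnt_epi32_via_double(v));
}
```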

Doing integer operations on FP bit-patterns is very efficient with AVX2, just 1 cycle of extra latency for a bypass delay when doing integer shifts on the output of an FP math instruction.

(TODO: write it with C++ intrinsics, edit welcome or someone else could just post it as an answer.)


I'm not sure if you can do anything with int -> double conversion and then reading the exponent field. Negative numbers have no leading zeros and positive numbers give an exponent that depends on the magnitude.

If you did want that, you'd go one 128-bit lane at a time, shuffling to feed xmm -> ymm packed int32_t -> packed double conversion.
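A sketch of that per-lane approach follows, with hypothetical names, assuming GCC or Clang and AVX2. int32 -> double conversion is exact, so no rounding trick is needed. The cost is two conversions plus a blend and a lane-crossing permute to pack the high halves of the doubles back into 8 elements. Negative inputs keep their sign bit above the exponent and are clamped to 0.

```c
#include <immintrin.h>
#include <stdint.h>

__attribute__((target("avx2")))
static __m256i lzcnt_epi32_via_cvtpd(__m256i v) {
    // Exact signed int32 -> double conversion, one 128-bit half at a time.
    __m256i lo = _mm256_castpd_si256(_mm256_cvtepi32_pd(_mm256_castsi256_si128(v)));      // elems 0..3
    __m256i hi = _mm256_castpd_si256(_mm256_cvtepi32_pd(_mm256_extracti128_si256(v, 1))); // elems 4..7

    // The high dword of each double holds sign | 11-bit exponent | top mantissa bits.
    // Blend gives [e0, e4, e1, e5, e2, e6, e3, e7]; the permute restores element order.
    __m256i h = _mm256_blend_epi32(_mm256_srli_epi64(lo, 32), hi, 0xAA);
    h = _mm256_permutevar8x32_epi32(h, _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7));

    __m256i e = _mm256_srli_epi32(h, 20);                      // sign (bit 11) | exponent
    __m256i r = _mm256_sub_epi32(_mm256_set1_epi32(1054), e); // 1023 + 31 - exponent
    r = _mm256_min_epi32(r, _mm256_set1_epi32(32));           // input 0 -> +0.0 -> 1054 -> 32
    return _mm256_max_epi32(r, _mm256_setzero_si256());       // negative: e >= 2048 -> 0
}

// Hypothetical array wrapper so callers need not be compiled with AVX2 enabled.
__attribute__((target("avx2")))
void lzcnt_u32x8_via_cvtpd(const uint32_t in[8], uint32_t out[8]) {
    __m256i v = _mm256_loadu_si256((const __m256i *)in);
    _mm256_storeu_si256((__m256i *)out, lzcnt_epi32_via_cvtpd(v));
}
```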

answered Nov 05 '22 at 06:11 by Peter Cordes