
SIMD - AVX - masking with non-zero value instead of highest bit

Tags: c, avx, simd

I have AVX (no AVX2 or AVX-512). I have a vector of 32-bit values (only the 4 lowest bits are used; the rest are always zero):

[ 1010, 0000, 0000, 0000, 0000, 1010, 1010, 0000]

Internally, I keep the vector as __m256 because of bitwise operations, so the bits are treated as "float numbers". I need to export a single 8-bit number from the vector, containing 1 for each non-zero element and 0 for each zero element.

So for the above example, I need the 8-bit number 10000110.
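For reference, the desired result can be pinned down with a plain scalar loop. This is a sketch; `expected_mask` is a hypothetical helper, not something from the question. Note that _mm256_movemask_ps places element 0 in bit 0, so for the example vector the mask value is 0x61 (binary 01100001); the 10000110 above is the same bits listed element 0 first.

```c
#include <stdint.h>

/* Scalar reference for the desired result: bit i of the returned mask is 1
   when element i of the 8-element vector is non-zero. This uses the same
   bit order as _mm256_movemask_ps (element 0 -> bit 0). */
unsigned expected_mask(const uint32_t v[8])
{
    unsigned mask = 0;
    for (int i = 0; i < 8; i++) {
        if (v[i] != 0)
            mask |= 1u << i;
    }
    return mask;
}
```

For the vector {0xA, 0, 0, 0, 0, 0xA, 0xA, 0} this returns 0x61, which is a handy value to compare any SIMD version against.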

My idea is to use _mm256_cmp_ps and then _mm256_movemask_ps. However, I don't know whether the cmp will work correctly when the values are not really floats and can be any "junk". In that case, which operand should I compare against?

Or is there any other solution?

Martin Perry asked Aug 22 '19 13:08


2 Answers

Conceptually, what you're doing should work. Bit patterns with the upper 24 bits zero are valid floats. However, the non-zero ones are denormal (subnormal).

While it should work, there are two potential problems:

  1. If the FP mode is set to treat denormals as zero (the DAZ flag, often enabled together with flush-to-zero), then they will all be treated as zero, which breaks that approach.
  2. Because these are denormals, you may take large performance penalties, depending on whether the hardware can handle them natively.
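The denormal claim is easy to check with scalar code. A minimal sketch, assuming IEEE-754 binary32 floats and the default floating-point environment (DAZ and FTZ off); `bits_to_float` and `is_positive_subnormal` are hypothetical helper names. The bit pattern 1010 (0x0000000A) is a non-zero subnormal float, so it compares unequal to 0.0f only as long as denormals are not treated as zero.

```c
#include <float.h>
#include <stdint.h>
#include <string.h>

/* Reinterpret a 32-bit pattern as a float without violating strict
   aliasing (memcpy is the portable way to type-pun in C). */
float bits_to_float(uint32_t bits)
{
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* A positive float is subnormal (denormal) when it is non-zero but
   smaller than FLT_MIN, the smallest positive normalized float. */
int is_positive_subnormal(float f)
{
    return f > 0.0f && f < FLT_MIN;
}
```

If the program runs with DAZ enabled (for example, some fast-math build options on x86 set it at startup), the `!= 0.0f` test on such values may fail, which is exactly problem 1.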

Alternative Approach:

Since the upper 24 bits are zero, you can first turn each element into a normalized float and then do the floating-point comparison.

(Warning: untested code)

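#include <immintrin.h>

int to_mask(__m256 data){
    const __m256 MASK = _mm256_set1_ps(8388608.0f);  //  2^23, bit pattern 0x4B000000
    data = _mm256_or_ps(data, MASK);                 //  each element becomes 2^23 + x
    data = _mm256_cmp_ps(data, MASK, _CMP_NEQ_UQ);   //  all-ones where x != 0
    return _mm256_movemask_ps(data);                 //  one bit per element, element 0 -> bit 0
}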

Here, data is your input, where the upper 24 bits of each "float" are zero. Let's call the 8-bit integer held in each element x.

OR'ing with the bit pattern of 2^23 (0x4B000000) sets the exponent field, so each element becomes a normalized float with value 2^23 + x (x ends up in the low bits of the mantissa).

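Then compare against 2^23 as a float. The result is all-ones (sign bit set) only in elements where x is non-zero, and _mm256_movemask_ps collects those sign bits into the 8-bit mask. Because every compared value is now normalized, DAZ and denormal penalties no longer come into play.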
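The OR trick can be checked lane by lane in portable C. This is a scalar model of the vector code, not a replacement for it; `or_with_2_pow_23` and `to_mask_scalar` are hypothetical names, and it assumes IEEE-754 binary32 floats.

```c
#include <stdint.h>
#include <string.h>

/* Scalar model of one lane of to_mask: OR the bit pattern of 2^23
   (0x4B000000) into a small integer x, then read the result as a float.
   For 0 <= x < 2^23 the result is the normalized float 2^23 + x. */
float or_with_2_pow_23(uint32_t x)
{
    uint32_t bits = 0x4B000000u | x;
    float f;
    memcpy(&f, &bits, sizeof f);
    return f;
}

/* Scalar model of the whole function: bit i is set when lane i,
   after the OR, compares unequal to 2^23 (i.e. when x != 0). */
unsigned to_mask_scalar(const uint32_t v[8])
{
    unsigned mask = 0;
    for (int i = 0; i < 8; i++) {
        if (or_with_2_pow_23(v[i]) != 8388608.0f)
            mask |= 1u << i;
    }
    return mask;
}
```

Every value 2^23 + x with x < 2^23 is exactly representable in a float (24-bit significand), which is why the comparison against 2^23 is exact.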

Mysticial answered Oct 19 '22 20:10


Alternate answer, for future readers that do have AVX2

You can cast to __m256i and use SIMD integer compares.

This avoids any problems with DAZ treating these small-integer bit patterns as exactly zero, or microcode assists for denormal (aka subnormal) inputs.

There might be 1 extra cycle of bypass latency between vpcmpeqd and vmovmskps on some CPUs, but you still come out ahead because integer compare has lower latency than FP compare.

#include <immintrin.h>

int nonzero_positions_avx2(__m256 v)
{
    __m256i vi = _mm256_castps_si256(v);
    vi = _mm256_cmpeq_epi32(vi, _mm256_setzero_si256());          // all-ones where the element == 0
    return ~_mm256_movemask_ps(_mm256_castsi256_ps(vi)) & 0xFF;   // invert so 1 = non-zero
}
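A compare-for-equality against zero sets the lanes that *are* zero, so the movemask result has to be inverted (and limited to 8 bits) to mark the non-zero positions. A portable scalar model of the three steps makes this visible; `nonzero_positions_model` is a hypothetical name, and this sketch mirrors the lane semantics rather than being an efficient implementation.

```c
#include <stdint.h>

/* Scalar model of the AVX2 version: vpcmpeqd against zero yields an
   all-ones lane (sign bit set) where the element IS zero, movemask gathers
   those sign bits, and a NOT restricted to 8 bits flips them so that set
   bits mark the non-zero elements. */
unsigned nonzero_positions_model(const uint32_t v[8])
{
    unsigned eq_zero_mask = 0;
    for (int i = 0; i < 8; i++) {
        uint32_t lane = (v[i] == 0) ? 0xFFFFFFFFu : 0u; /* cmpeq result */
        eq_zero_mask |= (lane >> 31) << i;               /* movemask: sign bit */
    }
    return ~eq_zero_mask & 0xFFu;                        /* invert, keep 8 bits */
}
```

For the question's example vector this gives 0x61, the same mask as the AVX-only approach.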
Peter Cordes answered Oct 19 '22 20:10