
How to check overflow for multiplication of 16 bit integers in SSE?

I want to implement a simple function in SSE (a program like the Izhikevich spiking neuron model). It should work with 16-bit signed integers (8.8 fixed point), and it needs to check the overflow condition during an integration step and set an SSE mask if overflow occurred:

// initialized like following:
short I = 0x1BAD; // current injected to neuron
short vR = (short)0xF00D; // some reset value when spiked (negative)

// step to be vectorized:
short v0 = vR;
for(;;) {

    // v0*v0/16 likely overflows => use 32 bit (16.16)
    short v0_sqr = ((int)v0)*((int)v0) / (1<<(8+4)); // not sure how "(v0*v0)>>(8+4)" would affect sign..
     // or   ((int)v0)*((int)v0) >> (8+4); // arithmetic right shift
     // original paper used v' = (v0^2)/25 + ...

    short v1 = v0_sqr + v0 + I;
    int m; // mask is set when neuron fires
    if(v1_overflows_during_this_operation()) { // "v1 > 0x7FFF" - way to detect?
        m = 0xFFFFFFFF;
    } else {
        m = 0;
    }
    v0 = ( v1 & ~m ) | (vR & m );
}

But I haven't found an _mm_mul_epi16() instruction that would let me check the high word of the multiplication. Why is there no such instruction, and how is a task like v1_overflows_during_this_operation() supposed to be implemented in SSE?

asked Oct 01 '18 by xakepp35

1 Answer

Unlike 32x32 => 64 (_mm_mul_epi32 / _mm_mul_epu32), there is no widening 16x16 => 32 SSE multiplication instruction.

Instead, there's _mm_mulhi_epi16 and _mm_mulhi_epu16 which give you just the signed or unsigned upper half of the full result.

(and _mm_mullo_epi16, a packed 16x16 => 16-bit low-half truncating multiply, which is the same for signed and unsigned).

You could use _mm_unpacklo_epi16 / _mm_unpackhi_epi16 to interleave the low/high halves into a pair of vectors with 32-bit elements, but that would be pretty slow. But yes, you could arithmetic-right-shift those by 12 with _mm_srai_epi32(v, 8+4) and then re-pack, maybe with _mm_packs_epi32 (signed saturation back to 16-bit). Then I guess check for saturation?
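For reference, a minimal sketch of that widening approach, using the standard mullo/mulhi + unpack idiom (the helper name is mine, not an intrinsic):

#include <emmintrin.h>  // SSE2

// (v0*v0) >> 12 via a full 16x16 => 32 widening multiply, then repack.
__m128i square_asr12_widening(__m128i v0) {
    __m128i lo = _mm_mullo_epi16(v0, v0);         // low 16 bits of each product
    __m128i hi = _mm_mulhi_epi16(v0, v0);         // signed high 16 bits
    __m128i prod_lo = _mm_unpacklo_epi16(lo, hi); // full 32-bit products, elements 0..3
    __m128i prod_hi = _mm_unpackhi_epi16(lo, hi); // full 32-bit products, elements 4..7
    prod_lo = _mm_srai_epi32(prod_lo, 8+4);       // arithmetic >> 12
    prod_hi = _mm_srai_epi32(prod_hi, 8+4);
    return _mm_packs_epi32(prod_lo, prod_hi);     // signed saturation back to 16-bit
}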


Your use case is unusual. There's _mm_mulhrs_epi16 which gives you the high 17 bits, rounded off and then truncated to 16 bits. (See the description). That's useful for some fixed-point algorithms where inputs are scaled to put the result in the upper half, and you want to round including the low half instead of truncate.
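(Concretely, for each 16-bit element pmulhrsw computes (((a * b) >> 14) + 1) >> 1, i.e. (a*b + 0x4000) >> 15: it keeps bits [30:15] of the full 32-bit product, rounded to nearest.)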

You might actually use _mm_mulhrs_epi16 or _mm_mulhi_epi16 as your best bet for keeping the most precision, maybe by left-shifting your v0 before squaring to just the point where the high half will give you (v0*v0) >> (8+4).

So do you think it is easier not to let the result overflow, and just to generate the mask with _mm_cmpgt_epi16(v1, vThreshold), as the author does in the original paper? (There is no _mm_cmpge_epi16; the signed 16-bit compares are gt/lt/eq.)

Hell yes! Gaining another bit or two of precision would cost you maybe a factor of 2 in performance, because you'd have to compute another multiply result to check for overflow, or effectively widen to 32-bit (cutting the number of elements per vector in half), as described above.

With a compare result, v0 = ( v1 & ~m ) | (vR & m ); becomes an SSE4.1 blend: _mm_blendv_epi8.
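In intrinsics that might look like the following (vThreshold and vR are assumed to be pre-broadcast vectors, e.g. from _mm_set1_epi16):

__m128i m = _mm_cmpgt_epi16(v1, vThreshold);   // 0xFFFF in lanes that fired
v0 = _mm_blendv_epi8(v1, vR, m);               // SSE4.1: per byte, m ? vR : v1
// SSE2 fallback, equivalent to the scalar AND/ANDNOT/OR:
// v0 = _mm_or_si128(_mm_and_si128(m, vR), _mm_andnot_si128(m, v1));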


If your vThreshold has 2 unset bits at the top, you have room to left shift by 2 without losing any of the most-significant bits. mulhi gives you (v0*v0) >> 16, so you can do this:

// shifts out the high 2 bits of v0 (redundant copies of the sign bit while v0 stays in range)
__m128i v0_lshift2   = _mm_slli_epi16(v0, 2);    // left by 2 before squaring
__m128i v0_sqr_asr12 = _mm_mulhi_epi16(v0_lshift2, v0_lshift2);
__m128i v1 = _mm_add_epi16(v0, I);               // I pre-broadcast, e.g. _mm_set1_epi16(I)
        v1 = _mm_add_epi16(v1, v0_sqr_asr12);

    // v1 = (((int)(v0<<2) * (v0<<2)) >> 16) + v0 + I

    //    = (((int)v0 * v0) >> 12) + v0 + I

Left shift by 2 before squaring is the same as left shift by 4 after squaring (of the full 32-bit result). It puts the 16 bits we want into the high 16 exactly.

But this is unusable if your v0 is so close to full range that you'd potentially overflow when left-shifting.

Otherwise, you can lose the 6 low bits of v0 before multiplying:

Rounding towards -Infinity with an arithmetic right shift loses 6 bits of precision, but overflow is impossible.

// losing the low 6 bits of v0
__m128i v0_asr6 = _mm_srai_epi16(v0, 6);
__m128i v0_sqr_asr12 = _mm_mullo_epi16(v0_asr6, v0_asr6);
__m128i v1 = _mm_add_epi16(v0, I);
        v1 = _mm_add_epi16(v1, v0_sqr_asr12);

    // v1 = ((int)(v0>>6) * (v0>>6)) + v0 + I

    // v1 ~= ((v0*(int)v0) >> 12) + v0 + I

I think you're losing more precision this way (and the truncation toward -Infinity gives maybe-worse rounding), so it's probably better to set vThreshold small enough that you have enough headroom to use high-half multiplies.
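Putting the high-half-multiply version together with the compare/blend, one full integration step might look like this sketch (the function and parameter names are mine; vI, vR, vThreshold are pre-broadcast vectors):

#include <smmintrin.h>  // SSE4.1 for _mm_blendv_epi8

// One step for 8 neurons at once; assumes vThreshold leaves 2 bits of
// headroom so the <<2 before squaring can't shift out significant bits.
static inline __m128i izhikevich_step(__m128i v0, __m128i vI,
                                      __m128i vR, __m128i vThreshold)
{
    __m128i v0_lshift2   = _mm_slli_epi16(v0, 2);
    __m128i v0_sqr_asr12 = _mm_mulhi_epi16(v0_lshift2, v0_lshift2); // (v0*v0) >> 12
    __m128i v1 = _mm_add_epi16(v0, vI);
    v1 = _mm_add_epi16(v1, v0_sqr_asr12);

    __m128i m = _mm_cmpgt_epi16(v1, vThreshold);   // 0xFFFF where the neuron fires
    return _mm_blendv_epi8(v1, vR, m);             // fired lanes reset to vR
}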

pmulhrsw to round instead of truncate might be even better, if we can set up for it as efficiently. But I don't think we can: pmulhrsw shifts right by 15 instead of 16, so the total pre-shift we'd need is 3 bits, an odd number that can't be split evenly between the two inputs. I think we'd need to make 2 separate inputs, one v0_lshift2 and one only left-shifted by 1.
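A sketch of that odd-shift setup (assuming v0 has 2 bits of headroom for the larger shift):

__m128i v0_lshift1 = _mm_slli_epi16(v0, 1);
__m128i v0_lshift2 = _mm_slli_epi16(v0, 2);
// SSSE3: ((v0<<1) * (v0<<2) + 0x4000) >> 15  ==  (v0*v0 + 0x800) >> 12
__m128i v0_sqr_asr12_rounded = _mm_mulhrs_epi16(v0_lshift1, v0_lshift2);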

answered by Peter Cordes