I want to implement a simple function in SSE (something like the Izhikevich spiking neuron model). It should work with 16-bit signed integers (8.8 fixed point), and it needs to check the overflow condition during an integration step and set an SSE mask if overflow occurred:
// initialized as follows:
short I = 0x1BAD; // current injected to neuron
short vR = 0xF00D; // some reset threshold when spiked (negative)
// step to be vectorized:
short v0 = vR;
for(;;) {
// v0*v0/16 likely overflows => use 32 bit (16.16)
short v0_sqr = ((int)v0)*((int)v0) / (1<<(8+4)); // not sure how "(v0*v0)>>(8+4)" would affect sign..
// or ((int)v0)*((int)v0) >> (8+4); // arithmetic right shift
// original paper used v' = (v0^2)/25 + ...
short v1 = v0_sqr + v0 + I;
int m; // mask is set when neuron fires
if(v1_overflows_during_this_operation()) // "v1 > 0x7FFF" - is there a way to detect this?
m = 0xFFFFFFFF;
else
m = 0;
v0 = ( v1 & ~m ) | ( vR & m );
}
But I haven't found an _mm_mul_epi16() instruction that would let me check the high word of the multiplication. Why doesn't it exist, and how is something like v1_overflows_during_this_operation() supposed to be implemented in SSE?
Unlike 32x32 => 64, there is no widening 16x16 => 32 SSE multiplication instruction. Instead, there's _mm_mulhi_epi16 and _mm_mulhi_epu16, which give you just the signed or unsigned upper half of the full result (and _mm_mullo_epi16, which does a packed 16x16 => 16-bit low-half truncating multiply, the same for signed or unsigned).
You could use _mm_unpacklo_epi16 / _mm_unpackhi_epi16 to interleave the low and high halves into a pair of vectors with 32-bit elements, but that would be pretty slow. Still, you could arithmetic-right-shift those by 12 with _mm_srai_epi32(v, 8+4) and then re-pack, maybe with _mm_packs_epi32 (signed saturation back to 16-bit). Then I guess check for saturation?
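As a rough sketch of that widening approach (my illustration, not code from the question; v0 is assumed to be a __m128i holding eight 8.8 values):

// Sketch only: full 16x16 -> 32 signed squaring, then >> 12 and saturating re-pack.
__m128i lo = _mm_mullo_epi16(v0, v0);             // low 16 bits of each 32-bit product
__m128i hi = _mm_mulhi_epi16(v0, v0);             // signed high 16 bits of each product
__m128i sqr_lo = _mm_unpacklo_epi16(lo, hi);      // 32-bit products for elements 0..3
__m128i sqr_hi = _mm_unpackhi_epi16(lo, hi);      // 32-bit products for elements 4..7
sqr_lo = _mm_srai_epi32(sqr_lo, 8+4);             // (v0*v0) >> 12, still 32-bit
sqr_hi = _mm_srai_epi32(sqr_hi, 8+4);
__m128i v0_sqr = _mm_packs_epi32(sqr_lo, sqr_hi); // back to 16-bit with signed saturation
// Any lane that saturated equals 0x7FFF here (these squares are non-negative),
// so comparing against _mm_set1_epi16(0x7FFF) would give the overflow mask.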
Your use case is unusual. There's _mm_mulhrs_epi16, which gives you the high 17 bits, rounded off and then truncated to 16 bits (see the description). That's useful for some fixed-point algorithms where inputs are scaled to put the result in the upper half, and you want to round including the low half instead of truncating.
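Per 16-bit lane, that works out to roughly this scalar model (my paraphrase of the documented operation, not code from the answer):

#include <stdint.h>
// Scalar model of one lane of _mm_mulhrs_epi16(a, b):
int16_t mulhrs_lane(int16_t a, int16_t b) {
    int32_t prod = (int32_t)a * b;     // full 32-bit product
    int32_t t = (prod >> 14) + 1;      // keep the high 18 bits, add a rounding bit
    return (int16_t)(t >> 1);          // i.e. (a*b) >> 15, rounded to nearest
}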
You might actually use _mm_mulhrs_epi16 or _mm_mulhi_epi16 as your best bet for keeping the most precision, maybe by left-shifting your v0 before squaring to just the point where the high half gives you (v0*v0) >> (8+4).
So do you think it would be easier not to let the result overflow at all, and just generate the mask with _mm_cmpgt_epi16(v1, vThreshold), as the author does in the original paper?
Hell yes! Gaining another bit or two of precision would cost you maybe a factor of 2 in performance, because you'd have to compute another multiply result to check for overflow, or effectively widen to 32-bit (cutting the number of elements per vector in half), as described above.
With a compare result, v0 = ( v1 & ~m ) | (vR & m ); becomes an SSE4.1 blend: _mm_blendv_epi8.
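A minimal sketch of that selection step, assuming v1, vR and vThreshold are already __m128i vectors of 8.8 values:

__m128i fired = _mm_cmpgt_epi16(v1, vThreshold);  // all-ones in each lane that spiked
__m128i v0 = _mm_blendv_epi8(v1, vR, fired);      // SSE4.1: take vR where the mask is set
// Without SSE4.1, the same blend with and/andnot/or:
// v0 = _mm_or_si128(_mm_andnot_si128(fired, v1), _mm_and_si128(fired, vR));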
If your vThreshold has 2 unset bits at the top, you have room to left shift without losing any of the most-significant bits. Since mulhi gives you (v0*v0) >> 16, you can do this:
// losing the high 2 bits of v0
__m128i v0_lshift2 = _mm_slli_epi16(v0, 2); // left by 2 before squaring
__m128i v0_sqr_asr12 = _mm_mulhi_epi16(v0_lshift2, v0_lshift2);
__m128i v1 = _mm_add_epi16(v0, I);
v1 = _mm_add_epi16(v1, v0_sqr_asr12);
// v1 = (((v0<<2) * (int)(v0<<2)) >> 16) + v0 + I
// v1 = ((v0*(int)v0) >> 12) + v0 + I
Left shift by 2 before squaring is the same as left shift by 4 after squaring (of the full 32-bit result). It puts the 16 bits we want into the high 16 exactly.
But this is unusable if your v0 is so close to full range that you'd potentially overflow when left-shifting. Alternatively, you can lose the low 6 bits of v0 before multiplying. Rounding towards -Infinity with an arithmetic right shift loses 6 bits of precision, but overflow is impossible.
// losing the low 6 bits of v0
__m128i v0_asr6 = _mm_srai_epi16(v0, 6);
__m128i v0_sqr_asr12 = _mm_mullo_epi16(v0_asr6, v0_asr6);
__m128i v1 = _mm_add_epi16(v0, I);
v1 = _mm_add_epi16(v1, v0_sqr_asr12);
// v1 = ((v0>>6) * (int)(v0>>6)) + v0 + I
// v1 ~= ((v0*(int)v0) >> 12) + v0 + I
I think you're losing more precision this way, so it's probably better to set vThreshold small enough that you have enough headroom to use high-half multiplies. This way also has maybe-worse rounding. Using pmulhrsw to round instead of truncate might be even better, if we can set it up as efficiently. But I don't think we can, because its total right-shift of 15 is odd, while left-shifting v0 before squaring only changes the shift by an even amount: I think we'd need to make 2 separate inputs, one v0_lshift2 and one left-shifted by only 1.
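Something like the following might work (my sketch, not from the answer above; it assumes v0 has enough headroom that neither left shift overflows):

// ((v0<<2) * (v0<<1)) == (v0*v0) << 3; pmulhrsw then does a rounded >> 15,
// so the result is a rounded (v0*v0) >> 12.
__m128i v0_lshift2 = _mm_slli_epi16(v0, 2);
__m128i v0_lshift1 = _mm_slli_epi16(v0, 1);
__m128i v0_sqr_asr12 = _mm_mulhrs_epi16(v0_lshift2, v0_lshift1);
__m128i v1 = _mm_add_epi16(_mm_add_epi16(v0, I), v0_sqr_asr12);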