 

Multiply-subtract in SSE

I am vectorizing a piece of code and at some point I have the following setup:

__m128i a = _mm_set1_epi16(99);                                // { 99,99,99,99,99,99,99,99 }
__m128i b = _mm_setr_epi16(100, 50, 119, 30, 99, 40, 50, 20);  // { 100,50,119,30,99,40,50,20 }

I am packing shorts into these registers, which is why each one holds 8 values. What I would like to do is subtract the corresponding value of a from the i'th element of b whenever that element is greater than or equal to it (in this case, a is filled with the constant 99). To this end, I first do a greater-than-or-equal comparison between b and a, which, for this example, yields:

c = { 1,0,1,0,1,0,0,0 }   // 1 where b[i] >= a[i], 0 otherwise

To complete the operation, I'd like to use a multiply-and-subtract, i.e. compute b -= a*c and store the result in b. The result would then be:

b = { 1,50,20,30,0,40,50,20 }
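For reference, here is a scalar sketch of the per-lane result being asked for, assuming 8 unsigned 16-bit values per vector. The function name conditional_subtract_ref is hypothetical; it only pins down the expected output that any SIMD version has to reproduce.

```c
#include <stdint.h>

/* Scalar reference: for each lane i, c is 1 when b[i] >= a[i] and 0
   otherwise, and then b[i] -= a[i] * c (hypothetical helper name). */
static void conditional_subtract_ref(uint16_t b[8], const uint16_t a[8])
{
    for (int i = 0; i < 8; ++i) {
        uint16_t c = (uint16_t)(b[i] >= a[i]);   /* 1 or 0, like the c vector */
        b[i] = (uint16_t)(b[i] - a[i] * c);      /* subtracts a[i] only if c == 1 */
    }
}
```

With a filled with 99 and b = { 100,50,119,30,99,40,50,20 }, this leaves b = { 1,50,20,30,0,40,50,20 }.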

Is there any operation that does such a thing? What I found were fused operations for Haswell, but I am currently working on Sandy Bridge. Also, if someone has a better idea for doing this, please let me know (e.g. I could do a logical subtract: if c is 1, subtract; otherwise do nothing).

a3mlord asked Jun 19 '15 16:06


2 Answers

You essentially want an SSE version of this code, right?

if (b >= a)
    t = b - a;
else
    t = b;
b = t;

Since we want to avoid conditionals in the SSE version, we can get rid of the control flow like this (note that the mask is inverted):

uint16_t mask = (b >= a) - 1;              // 0x0000 if b >= a, 0xFFFF otherwise
uint16_t tmp = b - a;
uint16_t d = (b & mask) | (tmp & ~mask);   // b where b < a, b - a elsewhere
b = d;
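One way to convince yourself that the branchless form matches the if/else is to compare both on many inputs. This is a scalar sketch with hypothetical function names, assuming uint16_t operands; the casts make the implicit int promotions explicit.

```c
#include <stdint.h>

/* The if/else version (hypothetical name). */
static uint16_t sub_if_ge_branchy(uint16_t b, uint16_t a)
{
    uint16_t t;
    if (b >= a)
        t = (uint16_t)(b - a);
    else
        t = b;
    return t;
}

/* Branchless version: mask is 0x0000 when b >= a and 0xFFFF when b < a,
   so the OR selects b - a or b respectively. */
static uint16_t sub_if_ge_branchless(uint16_t b, uint16_t a)
{
    uint16_t mask = (uint16_t)((b >= a) - 1);
    uint16_t tmp  = (uint16_t)(b - a);
    return (uint16_t)((b & mask) | (tmp & (uint16_t)~mask));
}
```

Because the domain is only 16 bits wide, checking every b for a spread of a values is cheap.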

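I've checked the _mm_cmpgt_epi16 intrinsic and it has a nice property: it returns either 0x0000 for false or 0xFFFF for true, instead of a single bit 0 or 1 (thereby eliminating the need for the first subtraction). Note, however, that it is a signed, strict greater-than comparison, so we compute a > b, which gives exactly the inverted mask used above (all ones where b < a) and still reduces the lanes where b == a. Therefore our SSE version might look like this:

__m128i mask = _mm_cmpgt_epi16(a, b);   // 0xFFFF where b < a, 0x0000 where b >= a
__m128i tmp = _mm_sub_epi16(b, a);
__m128i d = _mm_or_si128(_mm_and_si128(mask, b), _mm_andnot_si128(mask, tmp));   // b where b < a, b - a elsewhere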
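A self-contained SSE2 sketch of this blend follows, assuming an x86 target and signed 16-bit lanes (since _mm_cmpgt_epi16 compares signed values). It builds the mask as a > b so that lanes with b == a are also reduced, matching the b >= a condition. The function names sub_if_ge_blend and sub_if_ge_blend8 are hypothetical.

```c
#include <emmintrin.h> /* SSE2 intrinsics */
#include <stdint.h>

/* Per lane: b - a where b >= a, otherwise b (signed 16-bit lanes). */
static __m128i sub_if_ge_blend(__m128i b, __m128i a)
{
    __m128i mask = _mm_cmpgt_epi16(a, b);  /* 0xFFFF where b < a */
    __m128i tmp  = _mm_sub_epi16(b, a);    /* b - a in every lane */
    /* (mask & b) keeps b where b < a; (~mask & tmp) takes b - a elsewhere */
    return _mm_or_si128(_mm_and_si128(mask, b), _mm_andnot_si128(mask, tmp));
}

/* Hypothetical helper: apply the blend to 8 shorts stored in memory. */
static void sub_if_ge_blend8(int16_t b[8], const int16_t a[8])
{
    __m128i vb = _mm_loadu_si128((const __m128i *)b);
    __m128i va = _mm_loadu_si128((const __m128i *)a);
    _mm_storeu_si128((__m128i *)b, sub_if_ge_blend(vb, va));
}
```

On the question's data (a all 99, b = { 100,50,119,30,99,40,50,20 }) this yields { 1,50,20,30,0,40,50,20 }.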

EDIT: harold suggested a far simpler approach. The solution above may still be helpful if you need to change the else branch of the if/else.

uint16_t mask = ~((b >= a) - 1);   // 0xFFFF if b >= a, 0x0000 otherwise
uint16_t tmp = a & mask;           // a where b >= a, 0 elsewhere
b = b - tmp;

The SSE code will be:

__m128i lt = _mm_cmpgt_epi16(a, b);              // 0xFFFF where b < a (signed compare)
b = _mm_sub_epi16(b, _mm_andnot_si128(lt, a));   // (~lt & a): subtract a only where b >= a
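Applied to the original setting, where a is a single constant broadcast to all lanes, the mask-and-subtract step could be wrapped in a loop like the sketch below. The function name sub_threshold_if_ge and the scalar tail loop are my own assumptions, not part of the answer, and the signed compare again assumes the data fits in int16_t (x86 with SSE2).

```c
#include <emmintrin.h> /* SSE2 intrinsics */
#include <stddef.h>
#include <stdint.h>

/* For every element: if (b[i] >= threshold) b[i] -= threshold.
   8 shorts per iteration with the mask-and-subtract trick,
   then a scalar loop for the leftover elements. */
static void sub_threshold_if_ge(int16_t *b, size_t n, int16_t threshold)
{
    const __m128i a = _mm_set1_epi16(threshold);
    size_t i = 0;
    for (; i + 8 <= n; i += 8) {
        __m128i vb = _mm_loadu_si128((const __m128i *)(b + i));
        __m128i lt = _mm_cmpgt_epi16(a, vb);             /* 0xFFFF where b < a */
        vb = _mm_sub_epi16(vb, _mm_andnot_si128(lt, a)); /* subtract a where b >= a */
        _mm_storeu_si128((__m128i *)(b + i), vb);
    }
    for (; i < n; ++i)                                   /* scalar tail */
        if (b[i] >= threshold)
            b[i] = (int16_t)(b[i] - threshold);
}
```

Unaligned loads and stores are used so the array needs no special alignment; with 16-byte aligned data, _mm_load_si128 and _mm_store_si128 would also work.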
hayesti answered Oct 06 '22 00:10


Another alternative: if your inputs are unsigned, you can calculate

b = min(b, b-a);

This works because if a > b, then b - a wraps around and is guaranteed to be larger than b (it equals 65536 - (a - b), and a < 65536). For a <= b you always get a value between 0 and b inclusive.
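The wraparound argument can be checked with a scalar sketch on uint16_t values; the function name sub_if_ge_min is hypothetical.

```c
#include <stdint.h>

/* min(b, b - a) on unsigned 16-bit values. If a > b, the subtraction wraps
   to 65536 - (a - b), which is always larger than b, so b is kept;
   otherwise 0 <= b - a <= b is chosen. */
static uint16_t sub_if_ge_min(uint16_t b, uint16_t a)
{
    uint16_t d = (uint16_t)(b - a);   /* wraps modulo 65536 */
    return d < b ? d : b;
}
```

For example, sub_if_ge_min(100, 99) gives 1, while sub_if_ge_min(50, 99) computes d = 65487 and keeps 50.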

b = _mm_min_epu16(b, _mm_sub_epi16(b,a));

_mm_min_epu16 requires SSE4.1 or later, which Sandy Bridge supports (_mm_min_epu8 would require only SSE2).
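A self-contained sketch of the unsigned variant on 8 lanes follows. It assumes GCC or Clang on x86: the target("sse4.1") attribute enables SSE4.1 for this one function, so no global -msse4.1 flag is needed (with other compilers, enable SSE4.1 in the build instead). The function name sub_if_ge_u16x8 is hypothetical.

```c
#include <smmintrin.h> /* SSE4.1: _mm_min_epu16 */
#include <stdint.h>

/* Unsigned variant: b = min(b, b - a) per 16-bit lane.
   GCC/Clang-specific attribute; the CPU must support SSE4.1 at run time. */
static __attribute__((target("sse4.1"))) void sub_if_ge_u16x8(uint16_t b[8], const uint16_t a[8])
{
    __m128i vb = _mm_loadu_si128((const __m128i *)b);
    __m128i va = _mm_loadu_si128((const __m128i *)a);
    vb = _mm_min_epu16(vb, _mm_sub_epi16(vb, va)); /* wrapping sub, unsigned min */
    _mm_storeu_si128((__m128i *)b, vb);
}
```

Since the comparison is unsigned, lanes above 32767 are handled correctly too: with a all 99, b = { 40000,99,98,65535,0,1,32768,500 } becomes { 39901,0,98,65436,0,1,32669,401 }.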

chtz answered Oct 05 '22 23:10