Fastest way to do horizontal vector sum with AVX instructions [duplicate]


I have a packed vector of four 64-bit floating-point values.
I would like to get the sum of the vector's elements.

With SSE3 (and using 32-bit floats), I could just do the following:

v_sum = _mm_hadd_ps(v_sum, v_sum);
v_sum = _mm_hadd_ps(v_sum, v_sum);
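For reference, here is a minimal runnable sketch of the SSE idiom above. It is wrapped in a hypothetical function hsum_ps_sse3 that takes a float array; the array interface is an assumption made for illustration. It also assumes an x86-64 CPU with SSE3 and GCC or Clang, because the target("sse3") attribute is a GCC/Clang extension that lets the file compile without -msse3.

```c
#include <immintrin.h>

// Horizontal sum of four floats using two SSE3 haddps steps.
// target("sse3") is GCC/Clang-specific; with -msse3 it can be dropped.
__attribute__((target("sse3")))
float hsum_ps_sse3(const float v[4])
{
    __m128 v_sum = _mm_loadu_ps(v);     // {a, b, c, d}
    v_sum = _mm_hadd_ps(v_sum, v_sum);  // {a+b, c+d, a+b, c+d}
    v_sum = _mm_hadd_ps(v_sum, v_sum);  // {a+b+c+d} in every lane
    return _mm_cvtss_f32(v_sum);        // read the lowest lane
}
```

With v = {1, 2, 3, 4} this returns 10. The second hadd works here because a 128-bit vector has only one lane, so nothing ever needs to cross a lane boundary.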

Unfortunately, although AVX provides a _mm256_hadd_pd instruction, its result differs from that of the SSE version. I believe this is because most AVX instructions operate like SSE instructions on the low and high 128-bit halves separately, never crossing the 128-bit boundary.
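A small sketch makes the lane behaviour visible. The hypothetical helper hadd_pd_layout stores the result of _mm256_hadd_pd(a, b) to memory so that each element can be inspected. It assumes an AVX-capable x86-64 CPU and GCC or Clang (for the target("avx") attribute, which avoids needing -mavx); the array-based interface is likewise an assumption for illustration.

```c
#include <immintrin.h>

// Writes _mm256_hadd_pd(a, b) to out[4]. For a = {a0,a1,a2,a3} and
// b = {b0,b1,b2,b3} the result is {a0+a1, b0+b1, a2+a3, b2+b3}:
// each 128-bit half is processed independently.
__attribute__((target("avx")))
void hadd_pd_layout(const double a[4], const double b[4], double out[4])
{
    __m256d va = _mm256_loadu_pd(a);
    __m256d vb = _mm256_loadu_pd(b);
    _mm256_storeu_pd(out, _mm256_hadd_pd(va, vb));
}
```

For a = {1, 2, 3, 4} and b = {10, 20, 30, 40}, out becomes {3, 30, 7, 70}: the first pair comes from the low halves and the second pair from the high halves. With a == b, applying the instruction twice as in the SSE idiom gives {2(a0+a1), 2(a0+a1), 2(a2+a3), 2(a2+a3)}, so the two halves are never combined into a single total.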

Ideally, the solution I am looking for should follow these guidelines:
1) Use only AVX/AVX2 instructions (no SSE).
2) Do it in no more than 2-3 instructions.

However, any efficient or elegant way to do it (even one that doesn't follow the guidelines above) is welcome.

Thanks a lot for any help.

-Luigi Castelli

asked Mar 19 '12 18:03 by Luigi Castelli


2 Answers

If you have two __m256d vectors x1 and x2 that each contain four doubles that you want to horizontally sum, you could do:

__m256d x1, x2;
// calculate 4 two-element horizontal sums:
// lower 64 bits contain x1[0] + x1[1]
// next 64 bits contain x2[0] + x2[1]
// next 64 bits contain x1[2] + x1[3]
// next 64 bits contain x2[2] + x2[3]
__m256d sum = _mm256_hadd_pd(x1, x2);
// extract upper 128 bits of result
__m128d sum_high = _mm256_extractf128_pd(sum, 1);
// add upper 128 bits of sum to its lower 128 bits
__m128d result = _mm_add_pd(sum_high, _mm256_castpd256_pd128(sum));
// lower 64 bits of result contain the sum of x1[0], x1[1], x1[2], x1[3]
// upper 64 bits of result contain the sum of x2[0], x2[1], x2[2], x2[3]

So three instructions compute two of the horizontal sums you need. The above is untested, but it should illustrate the concept.
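A self-contained, compilable version of the sequence above might look like the following. It is wrapped in a hypothetical function hsum2_pd that loads the two vectors from double arrays and stores both totals to out. The array interface and the GCC/Clang target("avx") attribute are assumptions added so the snippet compiles without -mavx; it still needs an AVX-capable CPU at run time.

```c
#include <immintrin.h>

// out[0] = x1[0]+x1[1]+x1[2]+x1[3], out[1] = x2[0]+x2[1]+x2[2]+x2[3]
__attribute__((target("avx")))
void hsum2_pd(const double x1[4], const double x2[4], double out[2])
{
    __m256d v1 = _mm256_loadu_pd(x1);
    __m256d v2 = _mm256_loadu_pd(x2);
    // {x1[0]+x1[1], x2[0]+x2[1], x1[2]+x1[3], x2[2]+x2[3]}
    __m256d sum = _mm256_hadd_pd(v1, v2);
    // upper 128 bits: {x1[2]+x1[3], x2[2]+x2[3]}
    __m128d sum_high = _mm256_extractf128_pd(sum, 1);
    // add the upper half to the lower half (the cast emits no instruction)
    __m128d result = _mm_add_pd(sum_high, _mm256_castpd256_pd128(sum));
    _mm_storeu_pd(out, result);
}
```

With x1 = {1, 2, 3, 4} and x2 = {10, 20, 30, 40}, out becomes {10, 100}.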

answered Sep 25 '22 08:09 by Jason R


If you want just the sum, and a bit of scalar code is acceptable:

double hsum_pd(__m256d x)
{
    __m256d s = _mm256_hadd_pd(x, x); // s = {x[0]+x[1], x[0]+x[1], x[2]+x[3], x[2]+x[3]}
    return ((double*)&s)[0] + ((double*)&s)[2];
}
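For a single vector, a commonly suggested variant avoids hadd altogether, because on common x86 cores hadd is generally reported to be slower than a plain shuffle plus an add. This is a swapped-in technique, not what this answer uses. The sketch below (hypothetical name hsum4_pd; the array interface and the GCC/Clang target("avx") attribute are assumptions) folds the high 128-bit half onto the low half, then folds the high double onto the low one. When compiled for AVX, the 128-bit intrinsics are emitted as VEX-encoded instructions, so no legacy SSE encodings appear in the output.

```c
#include <immintrin.h>

// Horizontal sum of four doubles without hadd.
__attribute__((target("avx")))
double hsum4_pd(const double v[4])
{
    __m256d x = _mm256_loadu_pd(v);
    __m128d lo = _mm256_castpd256_pd128(x);      // {v0, v1} (no instruction)
    __m128d hi = _mm256_extractf128_pd(x, 1);    // {v2, v3}
    lo = _mm_add_pd(lo, hi);                     // {v0+v2, v1+v3}
    __m128d hi64 = _mm_unpackhi_pd(lo, lo);      // {v1+v3, v1+v3}
    return _mm_cvtsd_f64(_mm_add_sd(lo, hi64));  // (v0+v2) + (v1+v3)
}
```

After the load this is four instructions (extract, add, unpack, add). Note that it adds in the order (v0 + v2) + (v1 + v3), so for general inputs the result can differ in the last bit from the hadd-based version.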

answered Sep 23 '22 08:09 by RJVB