How to sum __m256 horizontally?

I would like to horizontally sum the components of a __m256 vector using AVX instructions. In SSE I could use

xmm = _mm_hadd_ps(xmm, xmm);
xmm = _mm_hadd_ps(xmm, xmm);

to get the result in the first component of the vector, but this does not scale to the 256-bit version of the function (_mm256_hadd_ps), which only adds pairs within each 128-bit lane.
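A minimal sketch of the problem, assuming GCC or Clang on an x86-64 CPU with AVX: because `_mm256_hadd_ps` never crosses the 128-bit lane boundary, applying it twice leaves the sum of the low four elements in element 0 and the sum of the high four in element 4, never the total. The helper name `double_hadd256` is hypothetical, and `__attribute__((target("avx")))` is a GCC/Clang extension that enables AVX for this one function (compiling with `-mavx` makes it redundant).

```c
#include <immintrin.h>

/* Hypothetical helper: applies the SSE-style "hadd twice" idiom to a
   256-bit vector and stores all eight resulting elements to out.
   target("avx") is a GCC/Clang extension; with -mavx it is redundant. */
__attribute__((target("avx")))
void double_hadd256(const float in[8], float out[8]) {
    __m256 v = _mm256_loadu_ps(in);   /* v = ( x7, ..., x0 ), x0 = in[0] */
    /* Each hadd stays inside its 128-bit lane:
       the low lane only sees x0..x3, the high lane only sees x4..x7. */
    v = _mm256_hadd_ps(v, v);
    v = _mm256_hadd_ps(v, v);
    /* out[0..3] = x0+x1+x2+x3, out[4..7] = x4+x5+x6+x7 */
    _mm256_storeu_ps(out, v);
}
```

With `in = {1, 2, ..., 8}`, element 0 ends up as 10 and element 4 as 26, so one more cross-lane step is still needed to reach the total of 36.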

What is the best way to compute the horizontal sum of a __m256 vector?

asked Nov 04 '12 13:11 by Yoav


2 Answers

This version should be optimal for both Intel Sandy/Ivy Bridge and AMD Bulldozer, and later CPUs.

#include <immintrin.h>

// x = ( x7, x6, x5, x4, x3, x2, x1, x0 )
float sum8(__m256 x) {
    // hiQuad = ( x7, x6, x5, x4 )
    const __m128 hiQuad = _mm256_extractf128_ps(x, 1);
    // loQuad = ( x3, x2, x1, x0 )
    const __m128 loQuad = _mm256_castps256_ps128(x);
    // sumQuad = ( x3 + x7, x2 + x6, x1 + x5, x0 + x4 )
    const __m128 sumQuad = _mm_add_ps(loQuad, hiQuad);
    // loDual = ( -, -, x1 + x5, x0 + x4 )
    const __m128 loDual = sumQuad;
    // hiDual = ( -, -, x3 + x7, x2 + x6 )
    const __m128 hiDual = _mm_movehl_ps(sumQuad, sumQuad);
    // sumDual = ( -, -, x1 + x3 + x5 + x7, x0 + x2 + x4 + x6 )
    const __m128 sumDual = _mm_add_ps(loDual, hiDual);
    // lo = ( -, -, -, x0 + x2 + x4 + x6 )
    const __m128 lo = sumDual;
    // hi = ( -, -, -, x1 + x3 + x5 + x7 )
    const __m128 hi = _mm_shuffle_ps(sumDual, sumDual, 0x1);
    // sum = ( -, -, -, x0 + x1 + x2 + x3 + x4 + x5 + x6 + x7 )
    const __m128 sum = _mm_add_ss(lo, hi);
    return _mm_cvtss_f32(sum);
}

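haddps is not efficient on any CPU (on Intel it decodes to two shuffle uops plus an add). The best you can do is one shuffle (to bring the high half down) and one add per step, repeating until one element is left. Narrowing to 128 bits as the first step benefits AMD CPUs before Zen 2, which split every 256-bit operation into two 128-bit halves, and is not a bad thing anywhere.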

See "Fastest way to do horizontal SSE vector sum on x86" for more details about efficiency.
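The one-shuffle-one-add rule can also be followed in a different order. The sketch below is a variant, not the answer's code: after narrowing to 128 bits it uses SSE3's `_mm_movehdup_ps` for the in-pair step and `_mm_movehl_ps` for the final step. It performs the same three shuffles and three adds as `sum8`, so no speed difference is claimed. The names `hsum_ps_sse3` and `hsum8_sse3` are hypothetical; the `target` attributes are GCC/Clang extensions, and an AVX-capable x86-64 CPU is assumed.

```c
#include <immintrin.h>

/* Hypothetical helper: horizontal sum of a __m128, one shuffle + one add per step.
   Vectors are written high-to-low, as in sum8: v = ( v3, v2, v1, v0 ). */
__attribute__((target("sse3")))
static float hsum_ps_sse3(__m128 v) {
    __m128 shuf = _mm_movehdup_ps(v);   /* ( v3, v3, v1, v1 ) */
    __m128 sums = _mm_add_ps(v, shuf);  /* element 2 = v2+v3, element 0 = v0+v1 */
    shuf = _mm_movehl_ps(shuf, sums);   /* element 0 = v2+v3 */
    sums = _mm_add_ss(sums, shuf);      /* element 0 = v0+v1+v2+v3 */
    return _mm_cvtss_f32(sums);
}

/* Hypothetical wrapper: narrow the 8 floats to 128 bits first, then reduce. */
__attribute__((target("avx")))
float hsum8_sse3(const float in[8]) {
    __m256 x = _mm256_loadu_ps(in);
    __m128 lo = _mm256_castps256_ps128(x);    /* ( x3, x2, x1, x0 ) */
    __m128 hi = _mm256_extractf128_ps(x, 1);  /* ( x7, x6, x5, x4 ) */
    return hsum_ps_sse3(_mm_add_ps(lo, hi));
}
```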

answered Nov 07 '22 11:11 by Marat Dukhan


This can be done with the following code:

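__m256 ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1); // swap the two 128-bit halves
ymm = _mm256_add_ps(ymm, ymm2);                     // both halves: ( x3+x7, x2+x6, x1+x5, x0+x4 )
ymm = _mm256_hadd_ps(ymm, ymm);
ymm = _mm256_hadd_ps(ymm, ymm);                     // every element now holds the total
float sum = _mm_cvtss_f32(_mm256_castps256_ps128(ymm));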

but there might be a better solution.
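The sequence above can be wrapped in a self-contained function to check it. This is a sketch assuming GCC or Clang on an AVX-capable x86-64 CPU; the name `hsum8_broadcast` is hypothetical and `__attribute__((target("avx")))` is a GCC/Clang extension. Because the halves are swapped and added before the in-lane `hadd`s, every element of the result, not only the first, ends up holding the total.

```c
#include <immintrin.h>

/* Hypothetical wrapper around permute + add + hadd + hadd.
   Stores all eight elements so the broadcast result can be inspected. */
__attribute__((target("avx")))
void hsum8_broadcast(const float in[8], float out[8]) {
    __m256 ymm = _mm256_loadu_ps(in);                   /* ( x7, ..., x0 ) */
    /* Swap halves: ymm2 = ( x3, x2, x1, x0, x7, x6, x5, x4 ) */
    __m256 ymm2 = _mm256_permute2f128_ps(ymm, ymm, 1);
    /* Both 128-bit lanes now hold ( x3+x7, x2+x6, x1+x5, x0+x4 ) */
    ymm = _mm256_add_ps(ymm, ymm2);
    /* Two in-lane horizontal adds finish the reduction in every element */
    ymm = _mm256_hadd_ps(ymm, ymm);
    ymm = _mm256_hadd_ps(ymm, ymm);
    _mm256_storeu_ps(out, ymm);
}
```

The trade-off, following the first answer, is that this version stays in 256-bit registers and uses haddps twice; in exchange the sum is already broadcast, which can be convenient when it is needed as a vector rather than a scalar.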

answered Nov 07 '22 10:11 by Yoav