 

Fastest way to perform AVX inner product operations with mixed (float, double) input vectors

I need to build a single-precision inner-product routine for mixed single/double-precision floating-point vectors, exploiting the AVX instruction set and its 256-bit SIMD registers.

Problem: one input vector is float (x), while the other is double (yD).

Hence, before computing the actual inner-product operations, I need to convert my input yD vector data from double to float.

Using the SSE2 instruction set, I was able to implement very fast code that does what I need, with performance very close to the case where both vectors x and y are float:

  void vector_operation(const size_t i) 
  {
    __m128 X = _mm_load_ps(x + i);
    //convert 2+2 doubles to floats and pack them into one __m128
    __m128 Y = _mm_movelh_ps(_mm_cvtpd_ps(_mm_load_pd(yD + i + 0)), _mm_cvtpd_ps(_mm_load_pd(yD + i + 2)));
    //inner-products accumulation
    res = _mm_add_ps(res, _mm_mul_ps(X, Y));
  }   
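
For context, the kernel above is called from a plain loop and the four partial sums in res are reduced at the end. A minimal sketch of that driver (the globals x, yD, res, the length n and the inner_product_sse name are assumptions for illustration only; n is a multiple of 4 and both arrays are 16-byte aligned, as _mm_load_ps/_mm_load_pd require):

  #include <immintrin.h>
  #include <cstddef>

  //assumed context (not part of the original code)
  const size_t n = 1024;
  alignas(16) float  x[n];
  alignas(16) double yD[n];
  __m128 res;

  float inner_product_sse()
  {
    res = _mm_setzero_ps();
    for (size_t i = 0; i < n; i += 4)
      vector_operation(i);                                //accumulate 4 products per step

    //horizontal sum of the 4 partial lanes of res
    __m128 t = _mm_add_ps(res, _mm_movehl_ps(res, res));  //lanes 0+2, 1+3
    t = _mm_add_ss(t, _mm_shuffle_ps(t, t, 0x55));        //add lane 1 into lane 0
    return _mm_cvtss_f32(t);
  }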

Now, hoping for a further speed-up, I implemented a corresponding version with the AVX instruction set:

  inline void vector_operation(const size_t i) 
  {
    __m256 X = _mm256_load_ps(x + i);
    __m128 yD1 = _mm_cvtpd_ps(_mm_load_pd(yD + i + 0));
    __m128 yD2 = _mm_cvtpd_ps(_mm_load_pd(yD + i + 2));
    __m128 yD3 = _mm_cvtpd_ps(_mm_load_pd(yD + i + 4));
    __m128 yD4 = _mm_cvtpd_ps(_mm_load_pd(yD + i + 6));

    __m128 Ylow = _mm_movelh_ps(yD1, yD2);
    __m128 Yhigh = _mm_movelh_ps(yD3, yD4);

    //Pack __m128 data inside __m256 
    __m256 Y = _mm256_permute2f128_ps(_mm256_castps128_ps256(Ylow), _mm256_castps128_ps256(Yhigh), 0x20);

    //inner-products accumulation 
    res = _mm256_add_ps(res, _mm256_mul_ps(X, Y));
  }

I also tested other AVX implementations using, for example, casting and insertion operations instead of permuting the data. Performance was similarly poor compared to the case where both x and y vectors are float.
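
One such cast-and-insert variant replaces the _mm256_permute2f128_ps line with something along these lines (just a sketch of that approach; pack_halves is only an illustrative name):

  //pack two __m128 halves into one __m256: cast the low half, insert the high half
  inline __m256 pack_halves(__m128 Ylow, __m128 Yhigh)
  {
    return _mm256_insertf128_ps(_mm256_castps128_ps256(Ylow), Yhigh, 1);
  }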

The problem with the AVX code is that no matter how I implemented it, its performance is far inferior to that achieved using float-only x and y vectors (i.e. when no double-to-float conversion is needed).

The conversion from double to float for the yD vector seems pretty fast, while a lot of time is lost in the line where the data is packed into the __m256 Y register.

Do you know if this is a well-known issue with AVX?

Do you have a solution that could preserve good performances?

Thanks in advance!

asked Mar 21 '18 by Liotro78


1 Answer

I rewrote your function and took better advantage of what AVX has to offer. I also used fused multiply-add at the end; if you can't use FMA, just replace that line with separate multiplication and addition. I only now noticed that my implementation uses unaligned loads while yours uses aligned ones, but I'm not gonna lose any sleep over it. :)

__m256 foo(float* x, double* yD, const size_t i, __m256 res_prev)
{
  __m256 X = _mm256_loadu_ps(x + i);

  // convert 4+4 doubles to floats with the 256-bit converts
  __m128 yD21 = _mm256_cvtpd_ps(_mm256_loadu_pd(yD + i + 0));
  __m128 yD43 = _mm256_cvtpd_ps(_mm256_loadu_pd(yD + i + 4));

  // combine the two 128-bit halves into one __m256
  __m256 Y = _mm256_set_m128(yD43, yD21);

  // res_prev + X*Y in a single fused multiply-add
  return _mm256_fmadd_ps(X, Y, res_prev);
}
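
For completeness, here is a rough sketch of how this can be driven from a loop, with a final horizontal reduction of the accumulator (the inner_product wrapper and the length n are my additions here; n is assumed to be a multiple of 8):

float inner_product(float* x, double* yD, size_t n)
{
  __m256 acc = _mm256_setzero_ps();
  for (size_t i = 0; i < n; i += 8)
    acc = foo(x, yD, i, acc);                       // accumulate 8 products per iteration

  // horizontal sum of the 8 lanes of acc
  __m128 lo = _mm256_castps256_ps128(acc);
  __m128 hi = _mm256_extractf128_ps(acc, 1);
  __m128 s  = _mm_add_ps(lo, hi);                   // fold upper 128 bits into lower
  s = _mm_add_ps(s, _mm_movehl_ps(s, s));           // lanes 0+2, 1+3
  s = _mm_add_ss(s, _mm_shuffle_ps(s, s, 0x55));    // add lane 1 into lane 0
  return _mm_cvtss_f32(s);
}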

I did a quick benchmark and compared the running times of your implementation and mine. I tried two different benchmarking approaches with several repetitions, and every time my code was around 15% faster. I used the MSVC 14.1 compiler and compiled the program with the /O2 and /arch:AVX2 flags.

EDIT: this is the disassembly of the function:

vcvtpd2ps   xmm3,ymmword ptr [rdx+r8*8+20h]  
vcvtpd2ps   xmm2,ymmword ptr [rdx+r8*8]  
vmovups     ymm0,ymmword ptr [rcx+r8*4]  

vinsertf128 ymm3,ymm2,xmm3,1  

vfmadd213ps ymm0,ymm3,ymmword ptr [r9] 

EDIT 2: this is the disassembly of your AVX implementation of the same algorithm:

vcvtpd2ps   xmm0,xmmword ptr [rdx+r8*8+30h]  
vcvtpd2ps   xmm1,xmmword ptr [rdx+r8*8+20h]  

vmovlhps    xmm3,xmm1,xmm0  
vcvtpd2ps   xmm0,xmmword ptr [rdx+r8*8+10h]  
vcvtpd2ps   xmm1,xmmword ptr [rdx+r8*8]  
vmovlhps    xmm2,xmm1,xmm0  

vperm2f128  ymm3,ymm2,ymm3,20h  

vmulps      ymm0,ymm3,ymmword ptr [rcx+r8*4]  
vaddps      ymm0,ymm0,ymmword ptr [r9]
answered Oct 20 '22 by Nejc