
How to optimise my AVX Code

I tried to translate the following code into AVX intrinsics in order to improve the performance:

for (int alpha = 0; alpha < 4; alpha++) {
    for (int k = 0; k < 3; k++) {
        for (int beta = 0; beta < 4; beta++) {
            for (int l = 0; l < 4 ; l++) {
                d2_phi[(alpha*3+k)*16 + beta*4+l] =
                    -   (d2_phi[(alpha*3+k)*16 + beta*dim+l]

                        +   b[k] * (  lam_12[ beta][alpha] *   a[l] 
                                    + lam_22[alpha][ beta] *   b[l] 
                                    + lam_23[alpha][ beta] * rjk[l]  )

                        + rjk[k] * (  lam_13[ beta][alpha] *   a[l] 
                                    + lam_23[ beta][alpha] *   b[l] 
                                    + lam_33[alpha][ beta] * rjk[l]  )
                        ) / sqrt_gamma;
            }
        }
    }
}

and tried it the following way (dim is 4 here, so beta*dim and beta*4 address the same elements):

// load sqrt_gamma, because it is constant
__m256d ymm7 = _mm256_broadcast_sd(&sqrt_gamma);        

for (int alpha=0; alpha < 4; alpha++) {
    for (int k=0; k < 3; k++) {
        // Load values that are only dependent on k
        __m256d ymm9 = _mm256_broadcast_sd(b+k);   // all   b[k]
        __m256d ymm8 = _mm256_broadcast_sd(rjk+k); // all rjk[k]

        for (int beta=0; beta < 4; beta++) {
            // Load the lambdas, because they will stay the same for nine iterations
            __m256d ymm15 = _mm256_broadcast_sd(lam_12_p + 4*beta + alpha);   // all lam_12[ beta][alpha]
            __m256d ymm14 = _mm256_broadcast_sd(lam_22_p + 4*alpha + beta);   // all lam_22[alpha][ beta]
            __m256d ymm13 = _mm256_broadcast_sd(lam_23_p + 4*alpha + beta);   // all lam_23[alpha][ beta]
            __m256d ymm12 = _mm256_broadcast_sd(lam_13_p + 4*beta + alpha);   // all lam_13[ beta][alpha]
            __m256d ymm11 = _mm256_broadcast_sd(lam_23_p + 4*beta + alpha);   // all lam_23[ beta][alpha]
            __m256d ymm10 = _mm256_broadcast_sd(lam_33_p + 4*alpha + beta);   //     lam_33[alpha][ beta]   

            // Load the values that depend on the innermost loop, which is removed due to AVX
            __m256d ymm6 =_mm256_load_pd(a);   //   a[0] ...   a[3]
            __m256d ymm5 =_mm256_load_pd(b);   //   b[0] ...   b[3]
            __m256d ymm4 =_mm256_load_pd(rjk); // rjk[0] ... rjk[3]
            __m256d ymm3 =_mm256_load_pd(d2_phi_p + (alpha*3+k)*16 + beta*dim); // d2_phi[(alpha*3+k)*16 + beta*dim] ... [+3]
            // Block that is later on multiplied with b[k]
            __m256d ymm2 = _mm256_mul_pd(ymm15, ymm6); // lam_12[ beta][alpha] * a[l]
            __m256d ymm1 = _mm256_mul_pd(ymm14, ymm5); // lam_22[alpha][ beta] * b[l];

            __m256d ymm0 = _mm256_add_pd(ymm2, ymm1);  // lam_12[ beta][alpha] * a[l] + lam_22[alpha][ beta]*b[l];

            ymm2 = _mm256_mul_pd(ymm13, ymm4);         // lam_23[alpha][ beta] * rjk[l]
            ymm0 = _mm256_add_pd(ymm2, ymm0);          // lam_12[ beta][alpha] * a[l] + lam_22[alpha][ beta]*b[l] + lam_23[alpha][ beta] * rjk[l];

            ymm0 = _mm256_mul_pd(ymm9, ymm0);          // b[k] * (first sum of three)


            // Block that is later on multiplied with rjk[k]
            ymm2 = _mm256_mul_pd(ymm12, ymm6); // lam_13[ beta][alpha] *  a[l]
            ymm1 = _mm256_mul_pd(ymm11, ymm5); // lam_23[ beta][alpha] *  b[l]

            ymm2 = _mm256_add_pd(ymm2, ymm1);  // lam_13[ beta][alpha] *  a[l] + lam_23[ beta][alpha]*b[l];

            ymm1 = _mm256_mul_pd(ymm10, ymm4); // lam_33[alpha][ beta] * rjk[l]
            ymm2 = _mm256_add_pd(ymm2, ymm1);  // lam_13[ beta][alpha] *  a[l] + lam_23[ beta][alpha]*b[l] + lam_33[alpha][ beta] *rjk[l]

            ymm2 = _mm256_mul_pd(ymm2, ymm8);  // rjk[k] * (second sum of three)
            ymm0 = _mm256_add_pd(ymm0, ymm2);  // add to temporal result in ymm0
            ymm0 = _mm256_add_pd(ymm3, ymm0);  // Old value of d2 Phi;

            ymm0 = _mm256_div_pd(ymm0, ymm7);   // all divided by sqrt_gamma

            _mm256_store_pd(d2_phi_p + (alpha*3+k)*16  + beta*dim, ymm0);
        }
    }
}

But the performance is bad. It is even slower than the auto-vectorized code generated by the Intel compiler. I tried the following things:

  • All data arrays are 64-byte aligned with __declspec(align(64))
  • The store at the end was replaced by a streaming store, _mm256_stream_pd (both points are sketched below)
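For reference, roughly what those two points look like (a minimal sketch, not my original source; store_result and use_streaming are made-up names for illustration, and __declspec(align(64)) is the ICC/MSVC spelling):

#include <immintrin.h>

// 64-byte aligned output buffer, ICC/MSVC style; GCC/Clang would use
// __attribute__((aligned(64))) or C11 alignment specifiers instead
__declspec(align(64)) static double d2_phi[192];

// Store one 4-double result either with a regular aligned store or with a
// streaming (non-temporal) store. The streaming store bypasses the cache, so
// it only pays off when the stored data is not read back soon afterwards;
// d2_phi is reloaded inside the kernel, so here it is unlikely to help.
static void store_result(double *dst, __m256d result, int use_streaming)
{
    if (use_streaming)
        _mm256_stream_pd(dst, result);   // non-temporal store, needs a 32-byte aligned address
    else
        _mm256_store_pd(dst, result);    // regular aligned store
}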

When I look at the generated assembly, I see that the auto-vectorized code fetches all parameters in every iteration (not, as I did, only in the loops they belong to). It also contains more arithmetic operations. Finally, its store at the end needs only half the time of mine (I repeat the code fragment 1,000,000 times), and I don't see a reason for that. (I used Intel VTune Amplifier to inspect the assembly and the time spent.)

Thanks for all help in advance!

asked Feb 11 '23 by user3572032

1 Answer

I'm putting this here as a second answer, as it is a different and much more extensive optimisation. The key is to change the order of the loops to reduce the number of redundant operations by hoisting many of the loads and arithmetic operations out of the innermost loop.

Original loop structure:

for (int alpha=0; alpha < 4; alpha++) {
    for (int k=0; k < 3; k++) {
        for (int beta=0; beta < 4; beta++) {
            for (int l=0; l < 4 ; l++) {

New loop structure:

for (int alpha=0; alpha < 4; alpha++) {
    for (int beta=0; beta < 4; beta++) {
        for (int k=0; k < 3; k++) {
            for (int l=0; l < 4 ; l++) {
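
In scalar terms the effect of the reordering looks roughly like this (a simplified sketch covering only the loop nest from your question, not the full function below; s1 and s2 are temporaries introduced here purely for illustration):

for (int alpha = 0; alpha < 4; alpha++) {
    for (int beta = 0; beta < 4; beta++) {
        // hoisted: these sums depend only on alpha and beta, so they are now
        // computed once per (alpha, beta) pair instead of once per (alpha, k, beta)
        double s1[4], s2[4];
        for (int l = 0; l < 4; l++) {
            s1[l] = lam_12[beta][alpha]*a[l] + lam_22[alpha][beta]*b[l] + lam_23[alpha][beta]*rjk[l];
            s2[l] = lam_13[beta][alpha]*a[l] + lam_23[beta][alpha]*b[l] + lam_33[alpha][beta]*rjk[l];
        }
        for (int k = 0; k < 3; k++)
            for (int l = 0; l < 4; l++)
                d2_phi[(alpha*3+k)*16 + beta*4 + l] =
                    -(d2_phi[(alpha*3+k)*16 + beta*4 + l]
                      + b[k]*s1[l] + rjk[k]*s2[l]) / sqrt_gamma;
    }
}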

A complete, tested and optimised implementation of your function:

static void foo(
    double lam_11[4][4],
    double lam_12[4][4],
    double lam_13[4][4],
    double lam_22[4][4],
    double lam_23[4][4],
    double lam_33[4][4],
    const double rjk[4],
    const double a[4],
    const double b[4],
    const double sqrt_gamma,
    const double SPab,
    const double d1_phi[16],
    double d2_phi[192])
{
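    // NOTE: 'dim' (== 4, matching the array layout) is assumed to be defined
    //       elsewhere in the translation unit, as in the original question.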
    const double re_sqrt_gamma = 1.0 / sqrt_gamma;   // divide once, then multiply by the reciprocal in the loops below

    memset(d2_phi, 0, 192*sizeof(double));   // zero the output (requires <string.h>)

    const __m256d ymm6 = _mm256_load_pd(a); // load the whole 4-vector 'a' into register

    {
        // load SPab, because it is constant
        const __m256d ymm0 = _mm256_broadcast_sd(&SPab);
        const __m256d ymm7 = _mm256_load_pd(b); // load the whole 4-vector 'b' into register
        const __m256d ymm8 = _mm256_load_pd(rjk); // load the whole 4-vector 'rjk' into register

        for (int alpha=0; alpha < 4; alpha++)
        {
            for (int beta=0; beta < 4; beta++)
            {
                // Load the three lambdas to all
                const __m256d ymm3 = _mm256_broadcast_sd(&lam_11[alpha][beta]);
                const __m256d ymm4 = _mm256_broadcast_sd(&lam_12[alpha][beta]);
                const __m256d ymm5 = _mm256_broadcast_sd(&lam_13[alpha][beta]);

                const __m256d ymm9 = _mm256_load_pd(d1_phi + beta*4);

                // Do the three Multiplications
                const __m256d ymm13 = _mm256_mul_pd(ymm4,ymm7); // lam_12[alpha][ beta] *  b[l] = PROD2
                const __m256d ymm14 = _mm256_mul_pd(ymm5,ymm8); // lam_13[alpha][ beta] * rjk[l] = PROD3
                const __m256d ymm15 = _mm256_mul_pd(ymm3,ymm6); // lam_11[alpha][ beta] *  a[l] = PROD1
                __m256d ymm12 = _mm256_add_pd(ymm15, ymm13); // PROD1 + PROD2 = PROD12
                ymm12 = _mm256_add_pd(ymm12, ymm14); // PROD12 + PROD3 = PROD123

                double* addr = d2_phi + alpha*3*16  + beta*dim;

                for (int k=0; k < 3; k++)
                {
                    const __m256d ymm1 = _mm256_broadcast_sd(&d1_phi[alpha*dim + k]); // load d1_phi[alpha*dim+k] to all
                    const __m256d ymm2 = _mm256_broadcast_sd(&a[k]); // load a[k] to all
                    const __m256d ymm10 = _mm256_mul_pd(ymm0, ymm1); // SPab * d1_phi[alpha*dim+k] = PRE
                    const __m256d ymm11 = _mm256_mul_pd(ymm10, ymm9); // PRE * d1_phi[beta*dim+l] = SUM1

                    __m256d ymm12t = _mm256_mul_pd(ymm12, ymm2); // a[k] * PROD123 = SUM2
                    ymm12t = _mm256_add_pd(ymm11, ymm12t); // SUM1 + SUM2

                    _mm256_store_pd(addr, ymm12t);

                    addr+=16;
                }
            }
        }
    }

    {
        const __m256d ymm4 =_mm256_load_pd(rjk); // rjk[i] until rjk[l+3]
        const __m256d ymm5 =_mm256_load_pd(b); // b[l] until b[l+3]

        // load sqrt_gamma, because it is constant
        const __m256d ymm7 = _mm256_broadcast_sd(&re_sqrt_gamma);

        for (int alpha=0; alpha < 4; alpha++)
        {
            for (int beta=0; beta < 4; beta++)
            {
                // Load the lambdas, because they will stay the same for nine iterations
                const __m256d ymm15 = _mm256_broadcast_sd(&lam_12[beta][alpha]);   // all lam_12[ beta][alpha]
                const __m256d ymm14 = _mm256_broadcast_sd(&lam_22[alpha][beta]);   // all lam_22[alpha][ beta]
                const __m256d ymm13 = _mm256_broadcast_sd(&lam_23[alpha][beta]);   // all lam_23[alpha][ beta]
                const __m256d ymm12 = _mm256_broadcast_sd(&lam_13[beta][alpha]);   // all lam_13[ beta][alpha]
                const __m256d ymm11 = _mm256_broadcast_sd(&lam_23[beta][alpha]); // all lam_23[ beta][alpha]
                const __m256d ymm10 = _mm256_broadcast_sd(&lam_33[alpha][beta]); // lam_33[alpha][ beta]

                __m256d ymm0, ymm1, ymm2;

                // Block that is later on multiplied with b[k]
                ymm2 = _mm256_mul_pd(ymm15 , ymm6); // lam_12[ beta][alpha] *  a[l]
                ymm1 = _mm256_mul_pd(ymm14 , ymm5); // lam_22[alpha][ beta] * b[l];
                ymm0 = _mm256_add_pd(ymm2, ymm1);   // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l];
                ymm2 = _mm256_mul_pd(ymm13 , ymm4); // lam_23[alpha][ beta] * rjk[l]
                ymm0 = _mm256_add_pd(ymm2, ymm0);   // lam_12[ beta][alpha]* a[l] + lam_22[alpha][ beta]*b[l] + lam_23[alpha][ beta] * rjk[l];

                // Block that is later on multiplied with rjk[k]
                ymm2 = _mm256_mul_pd(ymm12 , ymm6); // lam_13[ beta][alpha] *  a[l]
                ymm1 = _mm256_mul_pd(ymm11 , ymm5); // lam_23[ beta][alpha] *  b[l]
                ymm2 = _mm256_add_pd(ymm2, ymm1);   // lam_13[ beta][alpha] *  a[l] + lam_23[ beta][alpha]*b[l];
                ymm1 = _mm256_mul_pd(ymm10 , ymm4); // lam_33[alpha][ beta] * rjk[l]
                ymm2 = _mm256_add_pd(ymm2 , ymm1);  // lam_13[ beta][alpha] *  a[l] + lam_23[ beta][alpha]*b[l] + lam_33[alpha][ beta] *rjk[l]

                double* addr = d2_phi + alpha*3*16  + beta*dim;

                for (int k=0; k < 3; k++)
                {
                    // Load values that are only dependent on k
                    const __m256d ymm9 = _mm256_broadcast_sd(b+k); // all b[k]
                    const __m256d ymm8 = _mm256_broadcast_sd(rjk+k); // all rjk[k]

                    // Load the old d2_phi value (the innermost l loop is removed due to AVX)

                    const __m256d ymm3 =_mm256_load_pd(addr);

                    __m256d ymm0t, ymm1t, ymm2t;

                    // Block that is later on multiplied with b[k]

                    ymm0t = _mm256_mul_pd(ymm9 , ymm0); // b[k] * (first sum of three)

                    // Block that is later on multiplied with rjk[k]

                    ymm1t = _mm256_mul_pd(ymm2 , ymm8); // rjk[k] * (second sum of three)
                    ymm2t = _mm256_add_pd(ymm0t, ymm1t); // add to temporal result in ymm0
                    ymm1t = _mm256_add_pd(ymm3, ymm2t);  // Old value of d2 Phi;

                    ymm2t = _mm256_mul_pd(ymm1t, ymm7); // all divided by sqrt_gamma
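                    // Flip the sign to match the leading minus in the scalar code. SIGNMASK is
                    // assumed to be defined elsewhere as a vector with only the sign bit set in
                    // each lane, e.g. _mm256_castsi256_pd(_mm256_set1_epi64x(0x8000000000000000ULL)).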
                    ymm1t = _mm256_xor_pd(ymm2t, SIGNMASK);

                    _mm256_store_pd(addr, ymm1t);

                    addr += 16;
                }
            }
        }
    }
}

With your test harness the original AVX code ran in around 500 ms; the new version runs in around 200 ms, so that's a 2.5x throughput improvement.

Updated version of your test harness with original code and optimised code here: http://pastebin.com/yMPbYPjb
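
For readers who cannot reach the link, such a timing loop might look roughly like this (a sketch only, not the actual harness; REPEATS is a made-up name, the inputs are dummies, and foo() above is assumed to be in the same translation unit):

#include <stdio.h>
#include <time.h>

#define REPEATS 1000000

int main(void)
{
    // dummy, zero-initialised, 64-byte aligned inputs, just to drive the kernel;
    // a real benchmark also has to stop the compiler from hoisting the call
    __declspec(align(64)) static double lam[6][4][4];
    __declspec(align(64)) static double rjk[4], a[4], b[4];
    __declspec(align(64)) static double d1_phi[16], d2_phi[192];

    clock_t t0 = clock();
    for (int i = 0; i < REPEATS; i++)
        foo(lam[0], lam[1], lam[2], lam[3], lam[4], lam[5],
            rjk, a, b, 1.0 /* sqrt_gamma */, 1.0 /* SPab */, d1_phi, d2_phi);
    clock_t t1 = clock();

    printf("%.1f ms\n", 1000.0 * (double)(t1 - t0) / CLOCKS_PER_SEC);
    return 0;
}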

answered Feb 15 '23 by Paul R