Fast AVX512 modulo when same divisor

I have been trying to find divisors of potential factorial primes (numbers of the form n!±1), and since I recently bought a Skylake-X workstation I thought I could get some speedup using AVX-512 instructions.

The algorithm is simple: the main step is taking the modulo repeatedly with respect to the same divisor, and the main work is looping over a large range of n values. Here is the naïve approach written in C (P is a table of primes):

uint64_t factorial_naive(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i, residue;
for (i = 0; i < APP_BUFLEN; i++){
    residue = 2;
    for (n=3; n <= nmax; n++){
        residue *=  n;
        residue %= P[i];
        // Let's check if we found a factor
        if (nmin <= n){
            if( residue == 1){
                report_factor(n, -1, P[i]);
            }
            if(residue == P[i]- 1){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

The idea here is to check a large range of n, e.g. 1,000,000 -> 10,000,000, against the same set of divisors, so we take the modulo with respect to the same divisor several million times. Using DIV is very slow, and there are several possible approaches depending on the range of the calculation. In my case n is most likely less than 10^7 and a potential divisor p is less than 10,000 G (< 10^13), so the numbers fit in 64 bits and even in 53 bits, but the product of the maximum residue (p - 1) times n is larger than 64 bits. So I thought the simplest version of the Montgomery method doesn't work, because we would be taking the modulo of a number larger than 64 bits.
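
To make the size argument concrete, here is a minimal check (my own sketch, not from the original code, and assuming a GCC/Clang-style unsigned __int128) that the worst-case product (p - 1) * n overflows 64 bits for these ranges:

#include <inttypes.h>
#include <stdint.h>
#include <stdio.h>

int main(void)
{
    // illustrative bounds taken from the ranges stated above
    const uint64_t p_max = UINT64_C(10000000000000); // largest divisor, ~1e13
    const uint64_t n_max = UINT64_C(10000000);       // largest n, 1e7

    unsigned __int128 worst = (unsigned __int128)(p_max - 1) * n_max;

    printf("worst-case (p-1)*n ~ %.3e\n", (double)worst);      // ~1.0e+20
    printf("UINT64_MAX         ~ %.3e\n", (double)UINT64_MAX); // ~1.8e+19
    // so residue * n can overflow uint64_t, hence the FMA-based approach below
    return 0;
}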

I found some old code for PowerPC where FMA was used to get an exact product of up to 106 bits (I guess) when using doubles, so I converted this approach to AVX-512 assembler (Intel intrinsics). Here is a simple version of the FMA method. It is based on the work of Dekker (1971); "Dekker product" and the FMA version of "TwoProduct" are useful terms when searching for the rationale behind it. This approach has also been discussed on this forum (e.g. here).

int64_t factorial_FMA(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
uint64_t n, i;
double prime_double, prime_double_reciprocal, quotient, residue;
double nr, n_double, prime_times_quotient_high, prime_times_quotient_low;

for (i = 0; i < APP_BUFLEN; i++){
    residue = 2.0;
    prime_double = (double)P[i];
    prime_double_reciprocal = 1.0 / prime_double;
    n_double = 3.0;
    for (n=3; n <= nmax; n++){
        nr =  n_double * residue;
        quotient = fma(nr, prime_double_reciprocal, rounding_constant);
        quotient -= rounding_constant;
        prime_times_quotient_high= prime_double * quotient;
        prime_times_quotient_low = fma(prime_double, quotient, -prime_times_quotient_high);
        residue = fma(residue, n, -prime_times_quotient_high) - prime_times_quotient_low;

        if (residue < 0.0) residue += prime_double;
        n_double += 1.0;

        // Let's check if we found a factor
        if (nmin <= n){
            if( residue == 1.0){
                report_factor(n, -1, P[i]);
            }
            if(residue == prime_double - 1.0){
                report_factor(n, 1, P[i]);
            }
        }
    }
}
return EXIT_SUCCESS;
}

Here I have used the magic constant

static const double rounding_constant = 6755399441055744.0; 

which is 2^52 + 2^51, the rounding magic number for doubles.
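
As a quick illustration (my own sketch, not part of the original code; it assumes the default round-to-nearest mode and no -ffast-math), adding and then subtracting this constant rounds a double to the nearest integer, because the addition pushes the value into a range where the spacing between representable doubles is exactly 1.0:

#include <stdio.h>

static const double rounding_constant = 6755399441055744.0; // 2^52 + 2^51

int main(void)
{
    double x = 12345.678;
    double r = (x + rounding_constant) - rounding_constant; // round to nearest integer
    printf("%f -> %f\n", x, r); // prints 12345.678000 -> 12346.000000
    return 0;
}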

I converted this to AVX-512 (32 potential divisors per loop iteration) and analyzed the result with IACA. It reported "Throughput Bottleneck: Backend" and "Backend allocation was stalled due to unavailable allocation resources". I am not very experienced with assembler, so my question is: is there anything I can do to speed this up and resolve this backend bottleneck?

The AVX-512 code is below and can also be found on GitHub:

uint64_t factorial_AVX512_unrolled_four(uint64_t const nmin, uint64_t const nmax, const uint64_t *restrict P)
{
// we are trying to find a factor for factorial numbers: n! +- 1
// nmin is the minimum n we want to report and nmax is the maximum. P is the table of primes.
// we process 32 primes in one loop.
// the naive version of the algorithm is in the function factorial_naive
// and a simple version of the FMA-based approach is in the function factorial_simpleFMA

const double one_table[8] __attribute__ ((aligned(64))) ={1.0, 1.0, 1.0,1.0,1.0,1.0,1.0,1.0};


uint64_t n;

__m512d zero, rounding_const, one, n_double;

__m512i prime1, prime2, prime3, prime4;

__m512d residue1, residue2, residue3, residue4;
__m512d prime_double_reciprocal1, prime_double_reciprocal2, prime_double_reciprocal3, prime_double_reciprocal4;
__m512d quotient1, quotient2, quotient3, quotient4;
__m512d prime_times_quotient_high1, prime_times_quotient_high2, prime_times_quotient_high3, prime_times_quotient_high4;
__m512d prime_times_quotient_low1, prime_times_quotient_low2, prime_times_quotient_low3, prime_times_quotient_low4;
__m512d nr1, nr2, nr3, nr4;
__m512d prime_double1, prime_double2, prime_double3, prime_double4;
__m512d prime_minus_one1, prime_minus_one2, prime_minus_one3, prime_minus_one4;

__mmask8 negative_reminder_mask1, negative_reminder_mask2, negative_reminder_mask3, negative_reminder_mask4;
__mmask8 found_factor_mask11, found_factor_mask12, found_factor_mask13, found_factor_mask14;
__mmask8 found_factor_mask21, found_factor_mask22, found_factor_mask23, found_factor_mask24;

// load data and initialize variables for the loop
rounding_const = _mm512_set1_pd(rounding_constant);
one = _mm512_load_pd(one_table);
zero = _mm512_setzero_pd ();

// load primes used to sieve
prime1 = _mm512_load_epi64((__m512i *) &P[0]);
prime2 = _mm512_load_epi64((__m512i *) &P[8]);
prime3 = _mm512_load_epi64((__m512i *) &P[16]);
prime4 = _mm512_load_epi64((__m512i *) &P[24]);

// convert primes to double
prime_double1 = _mm512_cvtepi64_pd (prime1); // vcvtqq2pd
prime_double2 = _mm512_cvtepi64_pd (prime2); // vcvtqq2pd
prime_double3 = _mm512_cvtepi64_pd (prime3); // vcvtqq2pd
prime_double4 = _mm512_cvtepi64_pd (prime4); // vcvtqq2pd

// calculates 1.0/ prime
prime_double_reciprocal1 = _mm512_div_pd(one, prime_double1);
prime_double_reciprocal2 = _mm512_div_pd(one, prime_double2);
prime_double_reciprocal3 = _mm512_div_pd(one, prime_double3);
prime_double_reciprocal4 = _mm512_div_pd(one, prime_double4);

// for comparison if we have found factors for n!+1
prime_minus_one1 = _mm512_sub_pd(prime_double1, one);
prime_minus_one2 = _mm512_sub_pd(prime_double2, one);
prime_minus_one3 = _mm512_sub_pd(prime_double3, one);
prime_minus_one4 = _mm512_sub_pd(prime_double4, one);

// residue init
residue1 =  _mm512_set1_pd(2.0);
residue2 =  _mm512_set1_pd(2.0);
residue3 =  _mm512_set1_pd(2.0);
residue4 =  _mm512_set1_pd(2.0);

// double counter init
n_double = _mm512_set1_pd(3.0);

// main loop starts here. typical value for nmax can be 5,000,000 -> 10,000,000

for (n=3; n<=nmax; n++) // main loop
{

    // timings for instructions:
    // _mm512_load_epi64 = vmovdqa64 : L 1, T 0.5
    // _mm512_load_pd = vmovapd : L 1, T 0.5
    // _mm512_set1_pd
    // _mm512_div_pd = vdivpd : L 23, T 16
    // _mm512_cvtepi64_pd = vcvtqq2pd : L 4, T 0.5

    // _mm512_mul_pd = vmulpd :  L 4, T 0.5
    // _mm512_fmadd_pd = vfmadd132pd, vfmadd213pd, vfmadd231pd :  L 4, T 0.5
    // _mm512_fmsub_pd = vfmsub132pd, vfmsub213pd, vfmsub231pd : L 4, T 0.5
    // _mm512_sub_pd = vsubpd : L 4, T 0.5
    // _mm512_cmplt_pd_mask = vcmppd : L ?, T 1
    // _mm512_mask_add_pd = vaddpd :  L 4, T 0.5
    // _mm512_cmpeq_pd_mask = vcmppd : L ?, T 1
    // _mm512_kor = korw L 1, T 1

    // nr = residue *  n
    nr1 = _mm512_mul_pd (residue1, n_double);
    nr2 = _mm512_mul_pd (residue2, n_double);
    nr3 = _mm512_mul_pd (residue3, n_double);
    nr4 = _mm512_mul_pd (residue4, n_double);

    // quotient = nr * 1.0/ prime_double + rounding_constant
    quotient1 = _mm512_fmadd_pd(nr1, prime_double_reciprocal1, rounding_const);
    quotient2 = _mm512_fmadd_pd(nr2, prime_double_reciprocal2, rounding_const);
    quotient3 = _mm512_fmadd_pd(nr3, prime_double_reciprocal3, rounding_const);
    quotient4 = _mm512_fmadd_pd(nr4, prime_double_reciprocal4, rounding_const);

    // quotient -= rounding_constant, now quotient is rounded to an integer
    // the quotient should be at most nmax (10,000,000)
    quotient1 = _mm512_sub_pd(quotient1, rounding_const);
    quotient2 = _mm512_sub_pd(quotient2, rounding_const);
    quotient3 = _mm512_sub_pd(quotient3, rounding_const);
    quotient4 = _mm512_sub_pd(quotient4, rounding_const);

    // now we calculate the high and low parts of prime * quotient using the Dekker product (FMA).
    // the quotient comes from an approximation, but the product is exact for that quotient
    prime_times_quotient_high1 = _mm512_mul_pd(quotient1, prime_double1);
    prime_times_quotient_high2 = _mm512_mul_pd(quotient2, prime_double2);
    prime_times_quotient_high3 = _mm512_mul_pd(quotient3, prime_double3);
    prime_times_quotient_high4 = _mm512_mul_pd(quotient4, prime_double4);


    prime_times_quotient_low1 = _mm512_fmsub_pd(quotient1, prime_double1, prime_times_quotient_high1);
    prime_times_quotient_low2 = _mm512_fmsub_pd(quotient2, prime_double2, prime_times_quotient_high2);
    prime_times_quotient_low3 = _mm512_fmsub_pd(quotient3, prime_double3, prime_times_quotient_high3);
    prime_times_quotient_low4 = _mm512_fmsub_pd(quotient4, prime_double4, prime_times_quotient_high4);

    // now we calculate the new remainder from the original values, subtracting
    // the prime * quotient computed above (the quotient is an approximation)

    residue1 = _mm512_fmsub_pd(residue1, n_double, prime_times_quotient_high1);
    residue2 = _mm512_fmsub_pd(residue2, n_double, prime_times_quotient_high2);
    residue3 = _mm512_fmsub_pd(residue3, n_double, prime_times_quotient_high3);
    residue4 = _mm512_fmsub_pd(residue4, n_double, prime_times_quotient_high4);

    residue1 = _mm512_sub_pd(residue1, prime_times_quotient_low1);
    residue2 = _mm512_sub_pd(residue2, prime_times_quotient_low2);
    residue3 = _mm512_sub_pd(residue3, prime_times_quotient_low3);
    residue4 = _mm512_sub_pd(residue4, prime_times_quotient_low4);

    // let's check if the remainder is < 0
    negative_reminder_mask1 = _mm512_cmplt_pd_mask(residue1,zero);
    negative_reminder_mask2 = _mm512_cmplt_pd_mask(residue2,zero);
    negative_reminder_mask3 = _mm512_cmplt_pd_mask(residue3,zero);
    negative_reminder_mask4 = _mm512_cmplt_pd_mask(residue4,zero);

    // we add the prime back to the remainder, using the mask, if it was < 0
    residue1 = _mm512_mask_add_pd(residue1, negative_reminder_mask1, residue1, prime_double1);
    residue2 = _mm512_mask_add_pd(residue2, negative_reminder_mask2, residue2, prime_double2);
    residue3 = _mm512_mask_add_pd(residue3, negative_reminder_mask3, residue3, prime_double3);
    residue4 = _mm512_mask_add_pd(residue4, negative_reminder_mask4, residue4, prime_double4);

    n_double = _mm512_add_pd(n_double,one);

    // if we are below nmin then we continue to the next iteration
    if (n < nmin) continue;

    // Let's check if we found any factors: residue == 1 means a factor of n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue == prime - 1 means a factor of n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)
    { // we find a factor very rarely

        double *residual_list1 = (double *) &residue1;
        double *residual_list2 = (double *) &residue2;
        double *residual_list3 = (double *) &residue3;
        double *residual_list4 = (double *) &residue4;

        double *prime_list1 = (double *) &prime_double1;
        double *prime_list2 = (double *) &prime_double2;
        double *prime_list3 = (double *) &prime_double3;
        double *prime_list4 = (double *) &prime_double4;



        for (int i=0; i <8; i++){
            if( residual_list1[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list1[i]);
            }
            if( residual_list2[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list2[i]);
            }
            if( residual_list3[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list3[i]);
            }
            if( residual_list4[i] == 1.0)
            {
                report_factor((uint64_t) n, -1, (uint64_t) prime_list4[i]);
            }

            if(residual_list1[i] == (prime_list1[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list1[i]);
            }
            if(residual_list2[i] == (prime_list2[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list2[i]);
            }
            if(residual_list3[i] == (prime_list3[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list3[i]);
            }
            if(residual_list4[i] == (prime_list4[i] - 1.0))
            {
                report_factor((uint64_t) n, 1, (uint64_t) prime_list4[i]);
            }
        }
    }

}

return EXIT_SUCCESS;
}
asked Dec 17 '17 by Nuutti
1 Answer

As a few commenters have suggested: a "backend" bottleneck is what you'd expect for this code. That suggests you're keeping things pretty well fed, which is what you want.

Looking at the report, there should be an opportunity in this section:

    // Let's check if we found any factors: residue == 1 means a factor of n!-1
    found_factor_mask11 = _mm512_cmpeq_pd_mask(one, residue1);
    found_factor_mask12 = _mm512_cmpeq_pd_mask(one, residue2);
    found_factor_mask13 = _mm512_cmpeq_pd_mask(one, residue3);
    found_factor_mask14 = _mm512_cmpeq_pd_mask(one, residue4);

    // residue == prime - 1 means a factor of n!+1
    found_factor_mask21 = _mm512_cmpeq_pd_mask(prime_minus_one1, residue1);
    found_factor_mask22 = _mm512_cmpeq_pd_mask(prime_minus_one2, residue2);
    found_factor_mask23 = _mm512_cmpeq_pd_mask(prime_minus_one3, residue3);
    found_factor_mask24 = _mm512_cmpeq_pd_mask(prime_minus_one4, residue4);     

    if (found_factor_mask12 | found_factor_mask11 | found_factor_mask13 | found_factor_mask14 |
    found_factor_mask21 | found_factor_mask22 | found_factor_mask23|found_factor_mask24)

From the IACA analysis:

|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r11d, k0
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw eax, k1
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw ecx, k2
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw esi, k3
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw edi, k4
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r8d, k5
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r9d, k6
|   1      | 1.0         |      |             |             |      |      |      |      | kmovw r10d, k7
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, eax
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, ecx
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, esi
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, edi
|   1      |             | 1.0  |             |             |      |      |      |      | or r11d, r8d
|   1      |             |      |             |             |      |      | 1.0  |      | or r11d, r9d
|   1*     |             |      |             |             |      |      |      |      | or r11d, r10d

The processor is moving the resulting comparison masks (k0-k7) over to general-purpose registers for the "or" operations. You should be able to eliminate those moves and, in addition, do the "or" rollup in 6 ops vs 8.
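
A minimal sketch of what that rollup could look like with intrinsics, keeping everything in the mask registers (this is my illustration, not the poster's code, and it drops into the existing loop in place of the eight-way | test; _kor_mask8 requires AVX-512DQ, which Skylake-X has, and the wider _mm512_kor from AVX-512F would also work):

    __mmask8 any_factor =
        _kor_mask8(_kor_mask8(_kor_mask8(found_factor_mask11, found_factor_mask12),
                              _kor_mask8(found_factor_mask13, found_factor_mask14)),
                   _kor_mask8(_kor_mask8(found_factor_mask21, found_factor_mask22),
                              _kor_mask8(found_factor_mask23, found_factor_mask24)));

    if (any_factor)
    { // rare slow path: scan the lanes and report factors, as in the original code
    }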

NOTE: the found_factor_mask types are defined as __mmask8, where they should be __mmask16 (16x double floats in a 512-bit vector). That might let the compiler get at some optimizations. If not, drop to assembly as a commenter noted.

And related: what fraction of iterations fire this or-mask clause? As another commenter observed, you should be able to unroll this with an accumulating "or" operation. Check the accumulated "or" value at the end of each unrolled iteration (or after N iterations), and if it's "true", go back and re-do the values to figure out which n value triggered it.

(And, you can binary search within the "roll" to find the matching n value -- that might get some gain).
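
Here is a rough structural sketch of that accumulate-then-replay idea (my own, shown on a single scalar lane for brevity; BLOCK, fma_mod_step and scan_block are hypothetical names, and the vector version would accumulate the comparison masks with kor in the same way):

#include <math.h>
#include <stdbool.h>
#include <stdint.h>

#define BLOCK 64  // assumption: test the accumulated flag only every 64 steps

static const double rounding_constant = 6755399441055744.0; // 2^52 + 2^51

// one step of the FMA-based modular multiply from the question
static inline double fma_mod_step(double residue, double n, double p, double p_recip)
{
    double q  = fma(residue * n, p_recip, rounding_constant) - rounding_constant;
    double hi = p * q;
    double lo = fma(p, q, -hi);
    double r  = fma(residue, n, -hi) - lo;
    return (r < 0.0) ? r + p : r;
}

void scan_block(double *residue, double n_start, double p, double p_recip)
{
    double checkpoint = *residue; // save the state at the block start
    bool hit = false;

    for (int k = 0; k < BLOCK; k++) { // fast path: no per-step branch on the result
        *residue = fma_mod_step(*residue, n_start + k, p, p_recip);
        hit |= (*residue == 1.0) | (*residue == p - 1.0);
    }

    if (hit) { // rare path: replay (or binary search) to find which n fired
        double r = checkpoint;
        for (int k = 0; k < BLOCK; k++) {
            r = fma_mod_step(r, n_start + k, p, p_recip);
            if (r == 1.0)     { /* report_factor(n_start + k, -1, p) */ }
            if (r == p - 1.0) { /* report_factor(n_start + k, +1, p) */ }
        }
    }
}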

Next, you should be able to get rid of this mid-loop check:

    // if we are below nmin then we continue to the next iteration
    if (n < nmin) continue;

Which shows up here:

|   1*     |             |      |             |             |      |      |      |      | cmp r14, 0x3e8
|   0*F    |             |      |             |             |      |      |      |      | jb 0x229

It may not be a huge gain since the predictor will (probably) get this one (mostly) right, but you should get some gains by having two distinct loops for two "phases":

  • n=3 to n=nmin-1
  • n=nmin and beyond

Even if you gain a cycle, that's 3%. And since that's generally related to the big 'or' operation, above, there may be more cleverness in there to be found.
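
For illustration, a minimal sketch of that split in scalar form (my own; it reuses the hypothetical fma_mod_step helper from the sketch above and the variable names from your factorial_FMA function):

// phase 1: n = 3 .. nmin-1, the hot loop carries no factor checks at all
uint64_t n;
double residue = 2.0, n_double = 3.0;

for (n = 3; n < nmin; n++) {
    residue = fma_mod_step(residue, n_double, prime_double, prime_double_reciprocal);
    n_double += 1.0;
}

// phase 2: n = nmin .. nmax, advance the residue and also check for factors
for (; n <= nmax; n++) {
    residue = fma_mod_step(residue, n_double, prime_double, prime_double_reciprocal);
    n_double += 1.0;
    if (residue == 1.0)                report_factor(n, -1, P[i]);
    if (residue == prime_double - 1.0) report_factor(n, 1, P[i]);
}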

answered Nov 20 '22 by payne