Is there a way to make this function faster? (C)

I have C code that adds two numbers digit by digit, the way a person does on paper. For example, given two arrays A[0..n-1] and B[0..n-1], it computes C[0]=A[0]+B[0], C[1]=A[1]+B[1], and so on.

I need help making this function faster, even if the solution uses intrinsics.

My main problem is a loop-carried dependency: because I use base 10, iteration i+1 depends on the carry from iteration i. So if A[0]=6 and B[0]=5, C[0] must be 1 and there is a carry of 1 into the next addition.

The fastest code I could write was this one:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    for (int i = 0; i < N; i++) {
        Vout[i] = Vin1[i] + Vin2[i];
    } 

    unsigned char carry = 0;

    for (int i = 0; i < N; i++) {
        Vout[i] += carry;
        carry = Vout[i] / 10;
        Vout[i] = Vout[i] % 10;
    }
}

But I also tried these approaches, which turned out to be slower:

void LongNumAddition1(unsigned char *Vin1, unsigned char *Vin2,
                      unsigned char *Vout, unsigned N) {
    unsigned char CARRY = 0;
    for (int i = 0; i < N; i++) {
        unsigned char R = Vin1[i] + Vin2[i] + CARRY;
        Vout[i] = R % 10; CARRY = R / 10;
    }
}

void LongNumAddition1(char *Vin1, char *Vin2, char *Vout, unsigned N) {
    char CARRY = 0;
    for (int i = 0; i < N; i++) {
        char R = Vin1[i] + Vin2[i] + CARRY;
        if (R <= 9) {
            Vout[i] = R;
            CARRY = 0;
        } else {
            Vout[i] = R - 10;
            CARRY = 1;
        }
    }
}

I've been searching on Google and found some pseudocode similar to what I have implemented. GeeksforGeeks also has an implementation of this problem, but it is slower as well.

Can you please help me?

Jonathan Sánchez asked Apr 16 '20 12:04



2 Answers

If you don't want to change the format of the data, you can try SIMD.

#include <stddef.h>
#include <stdint.h>

typedef uint8_t u8x16 __attribute__((vector_size(16)));

void add_digits(uint8_t *const lhs, uint8_t *const rhs, uint8_t *out, size_t n) {
    uint8_t carry = 0;
    for (size_t i = 0; i + 15 < n; i += 16) {
        u8x16 digits = *(u8x16 *)&lhs[i] + *(u8x16 *)&rhs[i] + (u8x16){carry};

        // Get carries and almost-carries
        u8x16 carries = digits >= 10; // true is -1
        u8x16 full = digits == 9;

        // Shift carries
        carry = carries[15] & 1;
        __uint128_t carries_i = ((__uint128_t)carries) << 8;
        carry |= __builtin_add_overflow((__uint128_t)full, carries_i, &carries_i);

        // Add to carry chains and wrap
        digits += (((u8x16)carries_i) ^ full) & 1;
        // faster: digits = (u8x16)_mm_min_epu8((__m128i)digits, (__m128i)(digits - 10));
        digits -= (digits >= 10) & 10;

        *(u8x16 *)&out[i] = digits;
    }
}

This is ~2 instructions per digit. You'll need to add code to handle the tail-end.
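
For the tail, a minimal scalar sketch could look like the following. It assumes the vector loop above is changed to expose its final carry and the index where it stopped (it currently does neither); add_digits_tail and its parameters are hypothetical names, not part of the code above.

```c
#include <stddef.h>
#include <stdint.h>

/* Hypothetical helper: finishes digits [start, n) one at a time.
   `carry` is the carry left over from the vector loop (0 or 1).
   Returns the carry out of the last digit. */
static uint8_t add_digits_tail(const uint8_t *lhs, const uint8_t *rhs,
                               uint8_t *out, size_t start, size_t n,
                               uint8_t carry)
{
    for (size_t i = start; i < n; i++) {
        uint8_t sum = (uint8_t)(lhs[i] + rhs[i] + carry); /* at most 19 */
        carry = (uint8_t)(sum >= 10);
        out[i] = carry ? (uint8_t)(sum - 10) : sum;
    }
    return carry;
}
```

With a 16-digit vector loop, start would be n - n % 16.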


Here's a run-through of the algorithm.

First, we add our digits with our carry from the last iteration:

lhs           7   3   5   9   9   2
rhs           2   4   4   9   9   7
carry                             1
         + -------------------------
digits        9   7   9  18  18  10

We calculate which digits will produce carries (≥10), and which would propagate them (=9). With SIMD comparisons, true is -1 (all bits set), so the results can be used directly as masks.

carries       0   0   0  -1  -1  -1
full         -1   0  -1   0   0   0

We convert carries to an integer and shift it over by one lane, and also convert full to an integer. (The diagrams use 4 bits per digit for brevity; the code uses 8.)

              _   _   _   _   _   _
carries_i  000000001111111111110000
full       111100001111000000000000

Now we can add these together to propagate carries. Note that only the lowest bit of each lane is meaningful.

              _   _   _   _   _   _
carries_i  111100011110111111110000
(relevant) ___1___1___0___1___1___0

There are two indicators to look out for:

  1. carries_i has its lowest bit set, and digit ≠ 9. There has been a carry into this square.

  2. carries_i has its lowest bit unset, and digit = 9. There has been a carry over this square, resetting the bit.

We calculate this with (((u8x16)carries_i) ^ full) & 1, and add to digits.

(c^f) & 1     0   1   1   1   1   0
digits        9   7   9  18  18  10
         + -------------------------
digits        9   8  10  19  19  10

Then we remove the 10s, which have all been carried already.

digits        9   8  10  19  19  10
(d≥10)&10     0   0  10  10  10  10
         - -------------------------
digits        9   8   0   9   9   0

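We also keep track of the carry out of the block, which can come from two places: the top lane itself producing a carry (carries[15]), or the 128-bit addition overflowing because a carry ran into a top lane holding 9.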
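
To check the carry logic without vector extensions, the same steps can be modeled on 8 digits packed into a uint64_t, one byte per digit. This is only a sketch: add8_digits is a hypothetical name, the masks are built with a plain loop rather than SIMD compares, and the overflow is detected with a wrap-around test instead of __builtin_add_overflow.

```c
#include <stdint.h>

/* Hypothetical scalar model of one block: adds 8 decimal digits
   (least significant first) plus carry_in, returns the carry out. */
static unsigned add8_digits(const uint8_t *lhs, const uint8_t *rhs,
                            uint8_t *out, unsigned carry_in)
{
    uint8_t digits[8];
    uint64_t carries = 0, full = 0;

    for (int k = 0; k < 8; k++) {
        digits[k] = (uint8_t)(lhs[k] + rhs[k] + (k == 0 ? carry_in : 0));
        if (digits[k] >= 10) carries |= (uint64_t)0xFF << (8 * k); /* generates a carry */
        if (digits[k] == 9)  full    |= (uint64_t)0xFF << (8 * k); /* would propagate one */
    }

    /* Carry out, place 1: the top lane itself produced a carry. */
    unsigned carry_out = (unsigned)(carries >> 63);

    /* Move each carry to the lane above, then let the integer addition
       ripple it through consecutive lanes that hold 9. */
    uint64_t sum = full + (carries << 8);

    /* Carry out, place 2: the addition wrapped past the top lane. */
    carry_out |= (unsigned)(sum < full);

    /* Low bit of each lane: 1 where that digit receives a carry. */
    uint64_t inc = (sum ^ full) & 0x0101010101010101ULL;

    for (int k = 0; k < 8; k++) {
        uint8_t d = (uint8_t)(digits[k] + ((inc >> (8 * k)) & 1));
        out[k] = (d >= 10) ? (uint8_t)(d - 10) : d; /* remove the carried 10s */
    }
    return carry_out;
}
```

With the run-through's digits stored least significant first (lhs = {2,9,9,5,3,7,0,0}, rhs = {7,9,9,4,4,2,0,0}, carry_in = 1), this yields {0,9,9,0,8,9,0,0} and a carry out of 0, matching 735992 + 244997 + 1 = 980990.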

Veedrac answered Nov 14 '22 10:11



Candidates for speed improvement:

Optimizations

Be sure you have enabled your compiler's speed optimization settings.

restrict

The compiler does not know that changing Vout[] does not affect Vin1[] or Vin2[], and is thus limited in certain optimizations.

Use restrict to indicate Vin1[], Vin2[] are not affected by writing to Vout[].

// void LongNumAddition1(unsigned char  *Vin1, unsigned char *Vin2, unsigned char *Vout, unsigned N)
void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
   unsigned char * restrict Vout, unsigned N)

Note: this restricts the caller from calling the function with a Vout that overlaps Vin1, Vin2.

const

Also use const to aid optimizations. const also allows const arrays to be passed as Vin1, Vin2.

// void LongNumAddition1(unsigned char * restrict Vin1, unsigned char * restrict Vin2,
//    unsigned char * restrict Vout, unsigned N)
void LongNumAddition1(const unsigned char * restrict Vin1, 
   const unsigned char * restrict Vin2, 
   unsigned char * restrict Vout, 
   unsigned N)

unsigned

unsigned/int are the "go-to" types to use for integer math. Rather than unsigned char CARRY or char CARRY, use unsigned or uint_fast8_t from <stdint.h>.

% alternative

Replace % and / with a compare and subtract (@pmg), or the like:

sum = a+b+carry; if (sum >= 10) { sum -= 10; carry = 1; } else carry = 0;
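
As one variant of the idea, the compare can feed the subtraction directly so there is no explicit branch in the source. This is a sketch, and add_digit is a hypothetical helper name; whether it beats the if/else form depends on the compiler and target, so measure both.

```c
/* Hypothetical helper: adds two decimal digits plus *carry (0 or 1),
   returns the resulting digit and stores the new carry in *carry. */
static inline unsigned add_digit(unsigned a, unsigned b, unsigned *carry)
{
    unsigned sum = a + b + *carry; /* at most 9 + 9 + 1 = 19 */
    *carry = (sum >= 10);          /* 0 or 1, no division */
    return sum - 10u * *carry;     /* subtract 10 only when a carry occurred */
}
```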


Note: I would expect LongNumAddition1() to return the final carry.
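
Putting the suggestions together, a version that returns the final carry might look like this. It is a sketch under the assumption that digits are stored least significant first, as in the question; LongNumAdditionCarry is a hypothetical name.

```c
/* Hypothetical variant of LongNumAddition1: const and restrict on the
   pointers, unsigned for the arithmetic, compare-and-subtract instead
   of % and /, and the final carry returned to the caller. */
unsigned LongNumAdditionCarry(const unsigned char *restrict Vin1,
                              const unsigned char *restrict Vin2,
                              unsigned char *restrict Vout, unsigned N)
{
    unsigned carry = 0;
    for (unsigned i = 0; i < N; i++) {
        unsigned sum = Vin1[i] + Vin2[i] + carry; /* at most 19 */
        if (sum >= 10) {
            sum -= 10;
            carry = 1;
        } else {
            carry = 0;
        }
        Vout[i] = (unsigned char)sum;
    }
    return carry; /* 1 if the result needs one more digit */
}
```

Because of restrict, Vout must not overlap Vin1 or Vin2 when calling it.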

chux - Reinstate Monica answered Nov 14 '22 08:11
