
Most accurate way to do a combined multiply-and-divide operation in 64-bit?

What is the most accurate way I can do a multiply-and-divide operation for 64-bit integers that works in both 32-bit and 64-bit programs (in Visual C++)? (In case of overflow, I need the result mod 2^64.)

(I'm looking for something like MulDiv64, except that this one uses inline assembly, which only works in 32-bit programs.)

Obviously, casting to double and back is possible, but I'm wondering if there's a more accurate way that isn't too complicated. (i.e. I'm not looking for arbitrary-precision arithmetic libraries here!)
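
To make the accuracy concern concrete, here is a minimal sketch of the double round trip mentioned above. The name muldiv_double is only illustrative and is not an existing API; the sketch assumes c != 0 and a quotient that fits in 64 bits.

```cpp
#include <cassert>
#include <cstdint>

// Naive approach from the question: do the whole computation in double.
// A double carries a 53-bit significand, so a 64-bit operand can already be
// rounded by the conversion, before any arithmetic happens.
// Assumes c != 0 and that the quotient fits in 64 bits.
uint64_t muldiv_double(uint64_t a, uint64_t b, uint64_t c) {
    assert(c != 0);
    const double q = static_cast<double>(a) * static_cast<double>(b)
                   / static_cast<double>(c);
    return static_cast<uint64_t>(q);
}
```

For example, with a = 2^53 + 1 and b = c = 3 the exact answer is a itself, but on a typical IEEE 754 implementation the conversion of a to double rounds it to 2^53, so the result comes out one too small.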

user541686 asked Jan 04 '12 19:01


2 Answers

Since this is tagged Visual C++ I'll give a solution that abuses MSVC-specific intrinsics.

This example is fairly complicated. It's a highly simplified version of the same algorithm that is used by GMP and java.math.BigInteger for large division.

Although I have a simpler algorithm in mind, it's probably about 30x slower.

This solution has the following constraints/behavior:

  • It requires x64. It will not compile on x86.
  • The divisor must not be zero.
  • The quotient saturates if it overflows 64 bits.

Note that this is for the unsigned integer case. It's trivial to build a wrapper around this to make it work for signed cases as well. This example should also produce correctly truncated results.

This code is not fully tested. However, it has passed all the test cases that I've thrown at it, even cases that I intentionally constructed to try to break the algorithm.

    #include <intrin.h>
    #include <stdint.h>
    #include <iostream>
    using namespace std;

    uint64_t muldiv2(uint64_t a, uint64_t b, uint64_t c){
        //  Normalize divisor
        unsigned long shift;
        _BitScanReverse64(&shift,c);
        shift = 63 - shift;

        c <<= shift;

        //  Multiply: (b:a) becomes the 128-bit product, high word in b
        a = _umul128(a,b,&b);
        if (((b << shift) >> shift) != b){
            cout << "Overflow" << endl;
            return 0xffffffffffffffff;
        }
        b = __shiftleft128(a,b,shift);
        a <<= shift;

        uint32_t div;
        uint32_t q0,q1;
        uint64_t t0,t1;

        //  1st Reduction: estimate the upper 32 bits of the quotient
        div = (uint32_t)(c >> 32);
        t0 = b / div;
        if (t0 > 0xffffffff)
            t0 = 0xffffffff;
        q1 = (uint32_t)t0;
        while (1){
            t0 = _umul128(c,(uint64_t)q1 << 32,&t1);
            if (t1 < b || (t1 == b && t0 <= a))
                break;
            q1--;
    //        cout << "correction 0" << endl;
        }
        b -= t1;
        if (t0 > a) b--;
        a -= t0;

        if (b > 0xffffffff){
            cout << "Overflow" << endl;
            return 0xffffffffffffffff;
        }

        //  2nd reduction: estimate the lower 32 bits of the quotient
        t0 = ((b << 32) | (a >> 32)) / div;
        if (t0 > 0xffffffff)
            t0 = 0xffffffff;
        q0 = (uint32_t)t0;

        while (1){
            t0 = _umul128(c,q0,&t1);
            if (t1 < b || (t1 == b && t0 <= a))
                break;
            q0--;
    //        cout << "correction 1" << endl;
        }

    //    //  (a - t0) gives the modulus.
    //    a -= t0;

        return ((uint64_t)q1 << 32) | q0;
    }

Note that if you don't need a perfectly truncated result, you can remove the last loop completely. If you do this, the answer will be no more than 2 larger than the correct quotient.
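
As for the signed wrapper mentioned earlier, one way it could look is sketched below. This is not code from the answer: muldiv_signed and small_muldiv are hypothetical names, the wrapper assumes c != 0 and a result that fits in int64_t, and small_muldiv is only a stand-in that is valid when a * b fits in 64 bits. On x64 you would pass muldiv2 instead.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical stand-in for an unsigned multiply-and-divide routine.
// Only correct when a * b fits in 64 bits; used here to keep the sketch portable.
inline uint64_t small_muldiv(uint64_t a, uint64_t b, uint64_t c) {
    return a * b / c;
}

// Hypothetical wrapper: adapts any unsigned routine with the shape
// uint64_t(uint64_t, uint64_t, uint64_t) to signed operands.
// Assumes c != 0 and that the true result fits in int64_t (no overflow check).
template <typename UMulDiv>
int64_t muldiv_signed(int64_t a, int64_t b, int64_t c, UMulDiv umuldiv) {
    assert(c != 0);

    // The result is negative when an odd number of operands is negative.
    const bool negative = ((a < 0) != (b < 0)) != (c < 0);

    // Magnitudes are taken in unsigned arithmetic, so INT64_MIN is handled
    // without signed overflow: 0 - (uint64_t)x is well defined.
    auto magnitude = [](int64_t x) -> uint64_t {
        return x < 0 ? 0 - static_cast<uint64_t>(x) : static_cast<uint64_t>(x);
    };

    const uint64_t q = umuldiv(magnitude(a), magnitude(b), magnitude(c));

    // Truncating the magnitude rounds toward zero, like C++ signed division.
    return static_cast<int64_t>(negative ? 0 - q : q);
}
```

For instance, muldiv_signed(-7, 3, 2, small_muldiv) yields -10, the same as (-7 * 3) / 2 in C++.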

Test Cases:

    cout << muldiv2(4984198405165151231,6132198419878046132,9156498145135109843) << endl;
    cout << muldiv2(11540173641653250113, 10150593219136339683, 13592284235543989460) << endl;
    cout << muldiv2(449033535071450778, 3155170653582908051, 4945421831474875872) << endl;
    cout << muldiv2(303601908757, 829267376026, 659820219978) << endl;
    cout << muldiv2(449033535071450778, 829267376026, 659820219978) << endl;
    cout << muldiv2(1234568, 829267376026, 1) << endl;
    cout << muldiv2(6991754535226557229, 7798003721120799096, 4923601287520449332) << endl;
    cout << muldiv2(9223372036854775808, 2147483648, 18446744073709551615) << endl;
    cout << muldiv2(9223372032559808512, 9223372036854775807, 9223372036854775807) << endl;
    cout << muldiv2(9223372032559808512, 9223372036854775807, 12) << endl;
    cout << muldiv2(18446744073709551615, 18446744073709551615, 9223372036854775808) << endl;

Output:

    3337967539561099935
    8618095846487663363
    286482625873293138
    381569328444
    564348969767547451
    1023786965885666768
    11073546515850664288
    1073741824
    9223372032559808512
    Overflow
    18446744073709551615
    Overflow
    18446744073709551615
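
The simpler, slower algorithm mentioned earlier is not shown in the answer. One possible candidate, offered as an assumption rather than as the author's method, is a schoolbook 128-bit product followed by bit-by-bit long division. The names mul64x64 and muldiv_slow are made up for this sketch. It needs no intrinsics, so it also builds as a 32-bit program, and on overflow it returns the quotient mod 2^64 (what the question asks for) instead of saturating.

```cpp
#include <cassert>
#include <cstdint>

// Full 64x64 -> 128-bit product built from 32-bit halves (schoolbook method).
static void mul64x64(uint64_t a, uint64_t b, uint64_t& hi, uint64_t& lo) {
    const uint64_t a_lo = a & 0xffffffffu, a_hi = a >> 32;
    const uint64_t b_lo = b & 0xffffffffu, b_hi = b >> 32;

    const uint64_t p00 = a_lo * b_lo;
    const uint64_t p01 = a_lo * b_hi;
    const uint64_t p10 = a_hi * b_lo;
    const uint64_t p11 = a_hi * b_hi;

    // Middle column: at most 3 * (2^32 - 1), so the sum cannot overflow.
    const uint64_t mid = (p00 >> 32) + (p01 & 0xffffffffu) + (p10 & 0xffffffffu);

    lo = (mid << 32) | (p00 & 0xffffffffu);
    hi = p11 + (p01 >> 32) + (p10 >> 32) + (mid >> 32);
}

// floor(a * b / c) mod 2^64, using restoring (shift-and-subtract) division
// over the 128 bits of the product. Assumes c != 0.
uint64_t muldiv_slow(uint64_t a, uint64_t b, uint64_t c) {
    assert(c != 0);

    uint64_t hi, lo;
    mul64x64(a, b, hi, lo);

    uint64_t rem = 0;   // running remainder, always < c between iterations
    uint64_t quot = 0;  // only the low 64 quotient bits are kept
    for (int i = 127; i >= 0; --i) {
        const uint64_t bit = (i >= 64) ? (hi >> (i - 64)) & 1 : (lo >> i) & 1;
        // If the top bit of rem is set, 2 * rem + bit needs 65 bits and is
        // therefore certainly >= c.
        const bool carry = (rem >> 63) != 0;
        rem = (rem << 1) | bit;
        quot <<= 1;  // bits shifted out of the top are dropped: result mod 2^64
        if (carry || rem >= c) {
            rem -= c;  // wraps to the correct value when carry is set
            quot |= 1;
        }
    }
    return quot;
}
```

With the inputs from the test cases above, muldiv_slow(1234568, 829267376026, 1) gives 1023786965885666768, matching the listed output. The two saturating cases differ by design, since this sketch wraps instead of returning 18446744073709551615.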

Mysticial answered Oct 01 '22 11:10


You only need 64-bit integers. Some of the operations are redundant, but written this way the code also works with 10 as the base, which makes it easy to step through in the debugger.

    #include <stdint.h>

    uint64_t const base = 1ULL<<32;
    uint64_t const maxdiv = (base-1)*base + (base-1);

    uint64_t multdiv(uint64_t a, uint64_t b, uint64_t c)
    {
        // First get the easy thing
        uint64_t res = (a/c) * b + (a%c) * (b/c);
        a %= c;
        b %= c;
        // Are we done?
        if (a == 0 || b == 0)
            return res;
        // Is it easy to compute what remains to be added?
        if (c < base)
            return res + (a*b/c);
        // Now 0 < a < c, 0 < b < c, c >= base
        // Normalize
        uint64_t norm = maxdiv/c;
        c *= norm;
        a *= norm;
        // split into 2 digits
        uint64_t ah = a / base, al = a % base;
        uint64_t bh = b / base, bl = b % base;
        uint64_t ch = c / base, cl = c % base;
        // compute the product
        uint64_t p0 = al*bl;
        uint64_t p1 = p0 / base + al*bh;
        p0 %= base;
        uint64_t p2 = p1 / base + ah*bh;
        p1 = (p1 % base) + ah * bl;
        p2 += p1 / base;
        p1 %= base;
        // p2 holds 2 digits, p1 and p0 one

        // first digit is easy, not null only in case of overflow
        uint64_t q2 = p2 / c;
        p2 = p2 % c;

        // second digit, estimate
        uint64_t q1 = p2 / ch;
        // and now adjust
        uint64_t rhat = p2 % ch;
        // the loop can be unrolled, it will be executed at most twice for
        // even bases -- three times for odd one -- due to the normalisation above
        while (q1 >= base || (rhat < base && q1*cl > rhat*base+p1)) {
            q1--;
            rhat += ch;
        }
        // subtract
        p1 = ((p2 % base) * base + p1) - q1 * cl;
        p2 = (p2 / base * base + p1 / base) - q1 * ch;
        p1 = p1 % base + (p2 % base) * base;

        // now p1 holds 2 digits, p0 one and p2 is to be ignored
        uint64_t q0 = p1 / ch;
        rhat = p1 % ch;
        while (q0 >= base || (rhat < base && q0*cl > rhat*base+p0)) {
            q0--;
            rhat += ch;
        }
        // we don't need to do the subtraction (needed only to get the remainder,
        // in which case we have to divide it by norm)
        return res + q0 + q1 * base; // + q2 *base*base
    }

AProgrammer answered Oct 01 '22 12:10