How many 64-bit multiplications are needed to calculate the low 128-bits of a 64-bit by 128-bit product?

Consider that you want to calculate the low 128-bits of the result of multiplying a 64-bit and 128-bit unsigned number, and that the largest multiplication you have available is the C-like 64-bit multiplication which takes two 64-bit unsigned inputs and returns the low 64-bits of the result.

How many multiplications are needed?

Certainly you can do it with eight: break all the inputs up into 32-bit chunks and use your 64-bit multiplication to do the 4 * 2 = 8 required full-width 32*32->64 multiplications, but can one do better?

Of course the algorithm should do only a "reasonable" number of additions or other basic arithmetic on top of the multiplications (I'm not interested in solutions that re-invent multiplication as an addition loop and hence claim "zero" multiplications).
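For reference, the chunked baseline can be sketched in C (the `u128` struct and function name are my own scaffolding, nothing standard). Note that of the eight chunk products, a1*b3 has weight 2^128 and so never reaches the low 128 bits; seven products suffice even before any cleverness.

```c
#include <stdint.h>

/* illustrative 128-bit result pair: lo = bits 0..63, hi = bits 64..127 */
typedef struct { uint64_t lo, hi; } u128;

/* Low 128 bits of (64-bit a) * (128-bit b), schoolbook style, using only
   32x32->64 products.  b is passed as four 32-bit chunks, low to high. */
static u128 mullo_64x128(uint64_t a, uint32_t b0, uint32_t b1,
                         uint32_t b2, uint32_t b3)
{
    uint32_t a0 = (uint32_t)a, a1 = (uint32_t)(a >> 32);

    uint64_t p00 = (uint64_t)a0 * b0;   /* weight 2^0  */
    uint64_t p01 = (uint64_t)a0 * b1;   /* weight 2^32 */
    uint64_t p10 = (uint64_t)a1 * b0;   /* weight 2^32 */
    uint64_t p02 = (uint64_t)a0 * b2;   /* weight 2^64 */
    uint64_t p11 = (uint64_t)a1 * b1;   /* weight 2^64 */
    uint64_t p03 = (uint64_t)a0 * b3;   /* weight 2^96 */
    uint64_t p12 = (uint64_t)a1 * b2;   /* weight 2^96 */
    /* a1*b3 has weight 2^128 and cannot affect the low 128 bits */

    /* bits 32..63 of the result, plus the carries out of that column */
    uint64_t mid = (p00 >> 32) + (uint32_t)p01 + (uint32_t)p10;

    u128 r;
    r.lo = (uint32_t)p00 | (mid << 32);
    r.hi = (mid >> 32) + (p01 >> 32) + (p10 >> 32)
         + p02 + p11 + (p03 << 32) + (p12 << 32);   /* wraps mod 2^64 */
    return r;
}
```

Of the weight-2^96 products only the low 32 bits survive the final shift, so their high halves are discarded for free.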

asked Aug 17 '18 19:08 by BeeOnRope

2 Answers

Four, but it starts to get a little tricky.

Let a and b be the numbers to be multiplied, with a0 and a1 being the low and high 32 bits of a, respectively, and b0, b1, b2, b3 being 32-bit groups of b, from low to high respectively.

The desired result is the remainder of (a0 + a1•2^32) • (b0 + b1•2^32 + b2•2^64 + b3•2^96) modulo 2^128.

We can rewrite that as (a0 + a1•2^32) • (b0 + b1•2^32) + (a0 + a1•2^32) • (b2•2^64 + b3•2^96) modulo 2^128.

The remainder of the latter term modulo 2^128 can be computed as a single 64-bit by 64-bit multiplication (whose result is implicitly multiplied by 2^64).

Then the former term can be computed with three multiplications using a carefully implemented Karatsuba step. The simple version would involve a 33-bit by 33-bit to 66-bit product which is not available, but there is a trickier version that avoids it:

z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

The last line contains only one multiplication; the other two pseudo-multiplications are just conditional negations. Absolute-difference and conditional-negate are annoying to implement in pure C, but it could be done.
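One way the four multiplications could be assembled in C is sketched below. The helper names and the `u128` struct are mine, and this is just one realization of the conditional-negation step, not necessarily the one the answer has in mind:

```c
#include <stdint.h>

typedef struct { uint64_t lo, hi; } u128;   /* illustrative 128-bit pair */

/* Full 128-bit product of two 64-bit values using three 32x32->64
   multiplications: Karatsuba with conditional negation, as described. */
static u128 mul64x64_karatsuba3(uint64_t a, uint64_t b)
{
    uint32_t a0 = (uint32_t)a, a1 = (uint32_t)(a >> 32);
    uint32_t b0 = (uint32_t)b, b1 = (uint32_t)(b >> 32);

    uint64_t z0 = (uint64_t)a0 * b0;                 /* multiply 1 */
    uint64_t z2 = (uint64_t)a1 * b1;                 /* multiply 2 */

    uint32_t ad = a0 >= a1 ? a0 - a1 : a1 - a0;      /* |a0 - a1| */
    uint32_t bd = b1 >= b0 ? b1 - b0 : b0 - b1;      /* |b1 - b0| */
    int neg = (a0 >= a1) != (b1 >= b0);              /* sign of (a0-a1)(b1-b0) */
    uint64_t zm = (uint64_t)ad * bd;                 /* multiply 3 */

    /* middle = z0 + z2 + (a0-a1)(b1-b0) = a0*b1 + a1*b0; may need 65 bits */
    uint64_t mid = z0 + z2;
    uint64_t carry = mid < z0;                       /* carry out of z0 + z2 */
    if (neg) { carry -= mid < zm; mid -= zm; }
    else     { mid += zm; carry += mid < zm; }

    u128 r;
    r.lo = z0 + (mid << 32);
    r.hi = z2 + (mid >> 32) + (carry << 32) + (r.lo < z0);
    return r;
}

/* Low 128 bits of (64-bit a) * (128-bit b_hi:b_lo) in four multiplications. */
static u128 mul_64x128_lo4(uint64_t a, uint64_t b_lo, uint64_t b_hi)
{
    u128 r = mul64x64_karatsuba3(a, b_lo);  /* three multiplications */
    r.hi += a * b_hi;                       /* fourth: low 64 bits, weight 2^64 */
    return r;
}
```

The middle term is always non-negative (it equals a0*b1 + a1*b0), which is why the borrow in the `neg` branch can never underflow `carry`.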

answered Nov 05 '22 20:11 by harold


Of course, without Karatsuba, 5 multiplies: the four schoolbook products of the full 64 x 64 multiply, plus one for the high 64 bits of the 128-bit operand.

Karatsuba is wonderful, but these days a 64 x 64 multiply can be over in 3 clocks and a new one can be scheduled every clock. So the overhead of dealing with the signs and what not can be significantly greater than the saving of one multiply.

For a straightforward 64 x 64 multiply we need:

    r0 =       a0*b0
    r1 =    a0*b1
    r2 =    a1*b0
    r3 = a1*b1

  where we need to add r0 = r0 + (r1 << 32) + (r2 << 32)
            and add r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry

  where the carry is the carry from the additions to r0, and result is r3:r0.

#include <stdint.h>

typedef struct { uint64_t w0, w1 ; } uint64x2_t ;

uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
  uint64x2_t r ;
  uint64_t r1, r2, rx, ry ;
  uint32_t x1, x0 ;
  uint32_t m1, m0 ;

  x1    = (uint32_t)(x >> 32) ;
  x0    = (uint32_t)x ;
  m1    = (uint32_t)(m >> 32) ;
  m0    = (uint32_t)m ;

  r1    = (uint64_t)x1 * m0 ;
  r2    = (uint64_t)x0 * m1 ;
  r.w0  = (uint64_t)x0 * m0 ;
  r.w1  = (uint64_t)x1 * m1 ;

  rx    = (uint32_t)r1 ;
  rx    = rx + (uint32_t)r2 ;    // add the ls halves, collecting carry
  ry    = r.w0 >> 32 ;           // pick up ms of r0
  r.w0 += (rx << 32) ;           // complete r0
  rx   += ry ;                   // complete addition, rx >> 32 == carry !

  r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;

  return r ;
}

For Karatsuba, the suggested:

z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2

is trickier than it looks... for a start, if z1 is 64 bits, then need to somehow collect the carry which this addition can generate... and that is complicated by the signed-ness issues.

    z0 =       a0*b0
    z1 =    ax*bx        -- ax = (a1 - a0), bx = (b0 - b1)
    z2 = a1*b1

  where we need to add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
            and add r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry

  where the carry is the carry from the additions to create r0, and the result is r1:r0,

  and where we must take into account the signed-ness of ax, bx and z1.

uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
  uint64_t a0, a1, b0, b1 ;
  uint64_t ax, bx, zx, zy ;
  unsigned as, bs, xs ;
  uint64_t z0, z2 ;
  uint64x2_t r ;

  a0 = (uint32_t)a ; a1 = a >> 32 ;
  b0 = (uint32_t)b ; b1 = b >> 32 ;

  z0 = a0 * b0 ;
  z2 = a1 * b1 ;

  ax = (uint64_t)(a1 - a0) ;
  bx = (uint64_t)(b0 - b1) ;

  as = (unsigned)(ax > a1) ;            // sign of magic middle, a
  bs = (unsigned)(bx > b0) ;            // sign of magic middle, b
  xs = (unsigned)(as ^ bs) ;            // sign of magic middle, x = a * b

  ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;  // abs magic middle a
  bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;  // abs magic middle b

  zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
  xs = xs & (unsigned)(zx != 0) ;       // discard sign if z1 == 0 !

  zy = (uint32_t)zx ;                   // start ls half of z1
  zy = zy + (uint32_t)z0 + (uint32_t)z2 ;

  r.w0 = z0 + (zy << 32) ;              // complete ls word of result.
  zy   = zy + (z0 >> 32) ;              // complete carry

  zx   = (zx >> 32) - ((uint64_t)xs << 32) ;   // start ms half of z1
  r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;

  return r ;
}
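Given how easy it is to get the sign and carry handling wrong here, it seems worth checking the routine against the compiler's 128-bit arithmetic. The snippet below reproduces mulu64x2_karatsuba (lightly adapted to use standard `unsigned`) so it stands alone and can be compared with GCC/Clang's `unsigned __int128`, which is a compiler extension, not ISO C:

```c
#include <stdint.h>

typedef struct { uint64_t w0, w1 ; } uint64x2_t ;

uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
  uint64_t a0, a1, b0, b1 ;
  uint64_t ax, bx, zx, zy ;
  unsigned as, bs, xs ;
  uint64_t z0, z2 ;
  uint64x2_t r ;

  a0 = (uint32_t)a ; a1 = a >> 32 ;
  b0 = (uint32_t)b ; b1 = b >> 32 ;

  z0 = a0 * b0 ;
  z2 = a1 * b1 ;

  ax = (uint64_t)(a1 - a0) ;
  bx = (uint64_t)(b0 - b1) ;

  as = (unsigned)(ax > a1) ;            // sign of middle term, a side
  bs = (unsigned)(bx > b0) ;            // sign of middle term, b side
  xs = (unsigned)(as ^ bs) ;            // sign of the product ax * bx

  ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;  // |a1 - a0|
  bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;  // |b0 - b1|

  zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
  xs = xs & (unsigned)(zx != 0) ;       // discard sign if the product is 0

  zy = (uint32_t)zx ;                   // ls half of z1
  zy = zy + (uint32_t)z0 + (uint32_t)z2 ;

  r.w0 = z0 + (zy << 32) ;              // ls word of result
  zy   = zy + (z0 >> 32) ;              // complete carry

  zx   = (zx >> 32) - ((uint64_t)xs << 32) ;   // ms half of z1, sign corrected
  r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;

  return r ;
}
```

On GCC or Clang the same comparison can be driven over random inputs for more confidence.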

I did some very simple timings (using times(), running on Ryzen 7 1800X):

  • using gcc __int128 .............. ~780 'units'
  • using mulu64x2() ................ ~895
  • using mulu64x2_karatsuba() ...... ~1,095

...so, yes, you can save a multiply by using Karatsuba, but whether it's worth doing rather depends.

answered Nov 05 '22 21:11 by Chris Hall