Suppose you want to calculate the low 128 bits of the product of a 64-bit and a 128-bit unsigned number, and that the largest multiplication you have available is the C-like 64-bit multiplication, which takes two 64-bit unsigned inputs and returns the low 64 bits of the result.
How many multiplications are needed?
Certainly you can do it with eight: break all the inputs up into 32-bit chunks and use your 64-bit multiplication to do the 4 * 2 = 8 required full-width 32*32->64 multiplications, but can one do better?
Of course the algorithm should do only a "reasonable" number of additions or other basic arithmetic on top of the multiplications (I'm not interested in solutions that re-invent multiplication as an addition loop and hence claim "zero" multiplications).
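For reference, here is a sketch of the eight-multiplication baseline the question describes: ordinary schoolbook multiplication on 32-bit limbs, keeping only the low four limbs. The function name mul_64x128_schoolbook and the split of the 128-bit operand into b_hi:b_lo are illustrative assumptions, not part of the question.

```c
#include <stdint.h>

/* Hypothetical baseline: low 128 bits (*r_hi:*r_lo) of a * (b_hi:b_lo) by
   schoolbook multiplication on 32-bit limbs: 2 x 4 = 8 multiplies, each a
   full 32x32->64 product done with the 64-bit multiply. */
void mul_64x128_schoolbook(uint64_t a, uint64_t b_lo, uint64_t b_hi,
                           uint64_t *r_lo, uint64_t *r_hi)
{
    uint32_t x[2] = { (uint32_t)a, (uint32_t)(a >> 32) };
    uint32_t y[4] = { (uint32_t)b_lo, (uint32_t)(b_lo >> 32),
                      (uint32_t)b_hi, (uint32_t)(b_hi >> 32) };
    uint32_t out[6] = { 0 };                        /* full 192-bit product */

    for (int i = 0; i < 2; i++) {
        uint64_t carry = 0;
        for (int j = 0; j < 4; j++) {
            /* at most (2^32-1)^2 + 2*(2^32-1) = 2^64 - 1, so no overflow */
            uint64_t t = (uint64_t)x[i] * y[j] + out[i + j] + carry;
            out[i + j] = (uint32_t)t;
            carry = t >> 32;
        }
        out[i + 4] = (uint32_t)carry;
    }
    /* keep only the low four limbs, i.e. the result modulo 2^128 */
    *r_lo = out[0] | ((uint64_t)out[1] << 32);
    *r_hi = out[2] | ((uint64_t)out[3] << 32);
}
```

Since only the low four limbs are kept, the x[1] * y[3] product lands entirely in limbs 4 and 5 and could be skipped, so this baseline really needs seven.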
Four, but it starts to get a little tricky.
Let a and b be the numbers to be multiplied, with a0 and a1 being the low and high 32 bits of a, respectively, and b0, b1, b2, b3 being 32-bit groups of b, from low to high respectively.
The desired result is the remainder of $(a_0 + a_1 \cdot 2^{32}) \cdot (b_0 + b_1 \cdot 2^{32} + b_2 \cdot 2^{64} + b_3 \cdot 2^{96})$ modulo $2^{128}$.
We can rewrite that as $(a_0 + a_1 \cdot 2^{32}) \cdot (b_0 + b_1 \cdot 2^{32}) + (a_0 + a_1 \cdot 2^{32}) \cdot (b_2 \cdot 2^{64} + b_3 \cdot 2^{96})$ modulo $2^{128}$.
The remainder of the latter term modulo $2^{128}$ can be computed as a single 64-bit by 64-bit multiplication (whose result is implicitly multiplied by $2^{64}$).
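To spell out why one multiplication suffices for the latter term, write $b_{\mathrm{hi}} = b_2 + b_3 \cdot 2^{32}$ for the upper 64 bits of b (a name introduced here only for this illustration). Then

$$
a \cdot (b_2 \cdot 2^{64} + b_3 \cdot 2^{96}) = a \cdot b_{\mathrm{hi}} \cdot 2^{64} \equiv \left(a \cdot b_{\mathrm{hi}} \bmod 2^{64}\right) \cdot 2^{64} \pmod{2^{128}}
$$

so the wrapping C product a * b_hi is exactly the amount to add to the upper 64 bits of the result.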
Then the former term can be computed with three multiplications using a carefully implemented Karatsuba step. The simple version would involve a 33-bit by 33-bit to 66-bit product which is not available, but there is a trickier version that avoids it:
z0 = a0 * b0
z2 = a1 * b1
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
The last line contains only one multiplication; the other two pseudo-multiplications are just conditional negations. Absolute-difference and conditional-negate are annoying to implement in pure C, but it could be done.
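One possible way to write this in portable C is sketched below, combined with the single 64-bit multiply for the upper half of b, so the whole 64x128 -> low-128 product uses four multiplications. This is my own sketch, not code from this answer: the function name mul_64x128_karatsuba4 and the b_hi:b_lo split are hypothetical, the absolute differences and sign use plain comparisons, and the middle term, which by the identity (a0 - a1)(b1 - b0) + z0 + z2 = a0*b1 + a1*b0 can need 65 bits, is kept as a two-word pair mid_hi:mid_lo.

```c
#include <stdint.h>

/* Hypothetical helper: low 128 bits (*r_hi:*r_lo) of a * (b_hi:b_lo) using
   four 64-bit multiplies: three for a Karatsuba 64x64->128 step on a * b_lo,
   one wrapping multiply for the b_hi part. */
void mul_64x128_karatsuba4(uint64_t a, uint64_t b_lo, uint64_t b_hi,
                           uint64_t *r_lo, uint64_t *r_hi)
{
    uint64_t a0 = (uint32_t)a, a1 = a >> 32;
    uint64_t b0 = (uint32_t)b_lo, b1 = b_lo >> 32;

    uint64_t z0 = a0 * b0;                          /* multiply 1 */
    uint64_t z2 = a1 * b1;                          /* multiply 2 */

    /* (a0 - a1) * (b1 - b0) as a magnitude p and a sign flag neg */
    uint64_t da = a0 >= a1 ? a0 - a1 : a1 - a0;
    uint64_t db = b1 >= b0 ? b1 - b0 : b0 - b1;
    int neg = (a0 < a1) != (b1 < b0);
    uint64_t p = da * db;                           /* multiply 3, < 2^64 */

    /* mid = z0 + z2 +/- p = a0*b1 + a1*b0, a 65-bit value held as mid_hi:mid_lo */
    uint64_t mid_lo = z0 + z2;
    uint64_t mid_hi = mid_lo < z0;                  /* carry out of z0 + z2 */
    if (neg) {
        mid_hi -= mid_lo < p;                       /* borrow */
        mid_lo -= p;
    } else {
        mid_lo += p;
        mid_hi += mid_lo < p;                       /* carry */
    }

    /* result = z0 + mid * 2^32 + (z2 + a * b_hi) * 2^64, modulo 2^128 */
    uint64_t lo = z0 + (mid_lo << 32);
    uint64_t carry = lo < z0;
    *r_lo = lo;
    *r_hi = z2 + (mid_lo >> 32) + (mid_hi << 32) + carry
          + a * b_hi;                               /* multiply 4 */
}
```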
Of course, without Karatsuba, 5 multiplies.
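For comparison, a sketch of the five-multiplication version: four 32x32->64 products for a * b_lo, plus one wrapping 64-bit product for the upper half of b. The function and parameter names are hypothetical.

```c
#include <stdint.h>

/* Hypothetical helper: low 128 bits (*r_hi:*r_lo) of a * (b_hi:b_lo) using
   four 32x32->64 products for a * b_lo plus one wrapping 64-bit multiply. */
void mul_64x128_five(uint64_t a, uint64_t b_lo, uint64_t b_hi,
                     uint64_t *r_lo, uint64_t *r_hi)
{
    uint64_t a0 = (uint32_t)a, a1 = a >> 32;
    uint64_t b0 = (uint32_t)b_lo, b1 = b_lo >> 32;

    uint64_t p00 = a0 * b0, p01 = a0 * b1;          /* multiplies 1-2 */
    uint64_t p10 = a1 * b0, p11 = a1 * b1;          /* multiplies 3-4 */

    /* bits 32..63 column: high half of p00 plus low halves of the cross
       products; the sum is below 3 * 2^32, so mid >> 32 is the carry (0..2) */
    uint64_t mid = (p00 >> 32) + (uint32_t)p01 + (uint32_t)p10;

    *r_lo = (mid << 32) | (uint32_t)p00;
    *r_hi = p11 + (p01 >> 32) + (p10 >> 32) + (mid >> 32)
          + a * b_hi;                               /* multiply 5 */
}
```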
Karatsuba is wonderful, but these days a 64 x 64 multiply can complete in 3 clocks, and a new one can be started every clock. So the overhead of dealing with the signs and whatnot can be significantly greater than the saving of one multiply.
For a straightforward 64 x 64 multiply you need:
r0 = a0*b0
r1 = a0*b1
r2 = a1*b0
r3 = a1*b1
then add r0 = r0 + (r1 << 32) + (r2 << 32)
and r3 = r3 + (r1 >> 32) + (r2 >> 32) + carry
where carry is the total carry out of the additions to r0 (0, 1 or 2), and the result is r3:r0.
```c
#include <stdint.h>

typedef struct { uint64_t w0, w1 ; } uint64x2_t ;   // 128-bit result is w1:w0

uint64x2_t
mulu64x2(uint64_t x, uint64_t m)
{
  uint64x2_t r ;
  uint64_t r1, r2, rx, ry ;
  uint32_t x1, x0 ;
  uint32_t m1, m0 ;

  x1 = (uint32_t)(x >> 32) ;
  x0 = (uint32_t)x ;
  m1 = (uint32_t)(m >> 32) ;
  m0 = (uint32_t)m ;

  r1 = (uint64_t)x1 * m0 ;
  r2 = (uint64_t)x0 * m1 ;
  r.w0 = (uint64_t)x0 * m0 ;
  r.w1 = (uint64_t)x1 * m1 ;

  rx = (uint32_t)r1 ;
  rx = rx + (uint32_t)r2 ;   // add the ls halves, collecting carry
  ry = r.w0 >> 32 ;          // pick up ms of r0
  r.w0 += (rx << 32) ;       // complete r0
  rx += ry ;                 // complete addition, rx >> 32 == carry !
  r.w1 += (r1 >> 32) + (r2 >> 32) + (rx >> 32) ;
  return r ;
}
```
For Karatsuba, the suggested
z1 = abs(a0 - a1) * abs(b0 - b1) * sgn(a0 - a1) * sgn(b1 - b0) + z0 + z2
is trickier than it looks. For a start, if z1 is 64 bits, you need to somehow collect the carry that this addition can generate, and that is complicated by the signedness issues.
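To see why the carry is unavoidable, take the largest possible halves, $a_0 = a_1 = b_0 = b_1 = 2^{32} - 1$. The full middle term from the formula above, $z_1 = a_0 b_1 + a_1 b_0$, then exceeds 64 bits:

$$
z_1 = a_0 b_1 + a_1 b_0 = 2\,(2^{32}-1)^2 = 2^{65} - 2^{34} + 2 > 2^{64}
$$

so even though every individual product fits in 64 bits, their sum needs a 65th bit.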
z0 = a0*b0
z1 = ax*bx, where ax = (a1 - a0), bx = (b0 - b1)
z2 = a1*b1
then add r0 = z0 + (z1 << 32) + (z0 << 32) + (z2 << 32)
and r1 = z2 + (z1 >> 32) + (z0 >> 32) + (z2 >> 32) + carry
where carry is the carry out of the additions that create r0, and the result is r1:r0.
You must also take into account the signedness of ax, bx and z1: z1 can be negative, and its magnitude can need all 64 bits, so its sign has to be tracked separately.
```c
#include <stdint.h>

// uint64x2_t is the { w0, w1 } struct defined for mulu64x2 above
uint64x2_t
mulu64x2_karatsuba(uint64_t a, uint64_t b)
{
  uint64_t a0, a1, b0, b1 ;
  uint64_t ax, bx, zx, zy ;
  unsigned as, bs, xs ;
  uint64_t z0, z2 ;
  uint64x2_t r ;

  a0 = (uint32_t)a ; a1 = a >> 32 ;
  b0 = (uint32_t)b ; b1 = b >> 32 ;

  z0 = a0 * b0 ;
  z2 = a1 * b1 ;

  ax = (uint64_t)(a1 - a0) ;
  bx = (uint64_t)(b0 - b1) ;
  as = (unsigned)(ax > a1) ;                          // sign of magic middle, a
  bs = (unsigned)(bx > b0) ;                          // sign of magic middle, b
  xs = (unsigned)(as ^ bs) ;                          // sign of magic middle, x = a * b
  ax = (uint64_t)((ax ^ -(uint64_t)as) + as) ;        // abs magic middle a
  bx = (uint64_t)((bx ^ -(uint64_t)bs) + bs) ;        // abs magic middle b
  zx = (uint64_t)(((ax * bx) ^ -(uint64_t)xs) + xs) ;
  xs = xs & (unsigned)(zx != 0) ;                     // discard sign if z1 == 0 !

  zy = (uint32_t)zx ;                                 // start ls half of z1
  zy = zy + (uint32_t)z0 + (uint32_t)z2 ;
  r.w0 = z0 + (zy << 32) ;                            // complete ls word of result.
  zy = zy + (z0 >> 32) ;                              // complete carry
  zx = (zx >> 32) - ((uint64_t)xs << 32) ;            // start ms half of z1
  r.w1 = z2 + zx + (z0 >> 32) + (z2 >> 32) + (zy >> 32) ;
  return r ;
}
```
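The least obvious step in mulu64x2_karatsuba is zx = (zx >> 32) - ((uint64_t)xs << 32). It builds an arithmetic (sign-extending) right shift of the signed middle product $z_1$ out of a logical shift. When $z_1$ is negative (xs = 1), zx holds its two's-complement encoding, $\mathrm{zx} = z_1 + 2^{64}$, so

$$
\left\lfloor \frac{z_1}{2^{32}} \right\rfloor = \left\lfloor \frac{\mathrm{zx} - 2^{64}}{2^{32}} \right\rfloor = \left\lfloor \frac{\mathrm{zx}}{2^{32}} \right\rfloor - 2^{32}
$$

which, modulo $2^{64}$, is exactly (zx >> 32) - (1 << 32). When xs = 0 the subtraction vanishes and the logical shift is already correct. This is also why xs is cleared when zx == 0: negating a zero product leaves zx = 0, and treating that as $z_1 + 2^{64}$ would wrongly read it as $-2^{64}$.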
I did some very simple timings (using times(), running on a Ryzen 7 1800X):
...so, yes, you can save a multiply by using Karatsuba, but whether it's worth doing rather depends.