I am trying to implement the Karatsuba multiplication algorithm in C++, but right now I am just trying to get it to work in Python.
Here is my code:
def mult(x, y, b, m):
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    x0 = x / bm
    x1 = x % bm
    y0 = y / bm
    y1 = y % bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
What I don't get is: how should z2, z1, and z0 be created? Is using the mult function recursively correct? If so, I'm messing up somewhere because the recursion isn't stopping.
Can someone point out where the error is?
The Karatsuba algorithm can be applied recursively by splitting the multiplicand and the multiplier into two parts each, which makes it suitable for multiplying numbers with many digits. The cost of a multiplication operation is higher than the cost of addition and subtraction operations (Dwivedi, 2013).
To multiply two n-bit numbers, x and y, the Karatsuba algorithm performs three multiplications, plus a few additions and shifts, on smaller numbers that are roughly half the size of the original x and y.
The schoolbook (long multiplication) algorithm takes O(n^2) time. Using divide and conquer, Karatsuba multiplies two integers in O(n^(log2 3)) time, which is roughly O(n^1.585).
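For reference, the three-multiplication scheme described above can be sketched directly on Python integers. This is only a minimal sketch under a few assumptions that are not from the question: the name `karatsuba` is made up, the split is done on bits (with shifts and masks) rather than with the `b` and `m` parameters, the operands are non-negative, and the cutoff of 10 for the base case is arbitrary.

```python
def karatsuba(x, y):
    """Multiply two non-negative integers using three recursive products."""
    # Base case (arbitrary cutoff): small operands are multiplied directly.
    if x < 10 or y < 10:
        return x * y

    # Split both numbers at m bits, about half the length of the longer one.
    m = max(x.bit_length(), y.bit_length()) // 2
    mask = (1 << m) - 1
    x1, x0 = x >> m, x & mask  # high and low parts of x
    y1, y0 = y >> m, y & mask  # high and low parts of y

    z2 = karatsuba(x1, y1)                      # high * high
    z0 = karatsuba(x0, y0)                      # low * low
    z1 = karatsuba(x1 + x0, y1 + y0) - z2 - z0  # both cross terms, one product

    # Recombine with shifts, not with further recursive multiplications.
    return (z2 << (2 * m)) + (z1 << m) + z0


print(karatsuba(1234, 5678) == 1234 * 5678)  # True
```

Note that the recombination step uses shifts only. Calling the recursive function again to multiply by a power of the base is what keeps the question's version from terminating.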
NB: the response below directly addresses the OP's question about excessive recursion, but it does not attempt to provide a correct Karatsuba algorithm. The other responses are far more informative in this regard.
Try this version:
def mult(x, y, b, m):
    bm = pow(b, m)
    if min(x, y) <= bm:
        return x * y
    # NOTE the following 4 lines
    # (// is floor division; it matches Python 2's / on ints and also works in Python 3)
    x0 = x % bm
    x1 = x // bm
    y0 = y % bm
    y1 = y // bm
    z0 = mult(x0, y0, b, m)
    z2 = mult(x1, y1, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    retval = mult(mult(z2, bm, b, m) + z1, bm, b, m) + z0
    assert retval == x * y, "%d * %d == %d != %d" % (x, y, x * y, retval)
    return retval
The most serious problem with your version is that your calculations of x0 and x1, and of y0 and y1, are flipped. Also, the algorithm's derivation does not hold if x1 and y1 are 0, because in this case a factorization step becomes invalid. Therefore, you must avoid this possibility by ensuring that both x and y are greater than b**m.
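To see why the order of the split matters, here is a small sketch that performs only the split and the recombination. The sub-products use plain `*` so that nothing else can go wrong. The name `recombine` and the `flipped` switch are made up for this illustration and are not part of the question's code.

```python
def recombine(x, y, b, m, flipped=False):
    """Split x and y at b**m and rebuild the product from z2, z1, z0."""
    bm = b ** m
    x1, x0 = divmod(x, bm)  # x == x1 * bm + x0, so x1 is the high part
    y1, y0 = divmod(y, bm)  # y == y1 * bm + y0
    if flipped:
        # Reproduce the question's mistake: high and low parts swapped.
        x1, x0 = x0, x1
        y1, y0 = y0, y1
    z2 = x1 * y1
    z0 = x0 * y0
    z1 = (x1 + x0) * (y1 + y0) - z2 - z0  # equals x1*y0 + x0*y1
    return z2 * bm ** 2 + z1 * bm + z0


print(recombine(1234, 5678, 10, 2) == 1234 * 5678)                # True
print(recombine(1234, 5678, 10, 2, flipped=True) == 1234 * 5678)  # False
```

With the parts in the right order the result matches `x * y`. With them swapped, the low digits end up multiplied by `bm ** 2` and the result is wrong.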
EDIT: fixed a typo in the code; added clarifications
EDIT2:
To be clearer, commenting directly on your original version:
def mult(x, y, b, m):
    # The termination condition will never be true when the recursive
    # call is either
    #     mult(z2, bm ** 2, b, m)
    # or  mult(z1, bm, b, m)
    #
    # Since every recursive call leads to one of the above, you have an
    # infinite recursion condition.
    if max(x, y) < b:
        return x * y
    bm = pow(b, m)
    # Even without the recursion problem, the next four lines are wrong
    x0 = x / bm  # RHS should be x % bm
    x1 = x % bm  # RHS should be x / bm
    y0 = y / bm  # RHS should be y % bm
    y1 = y % bm  # RHS should be y / bm
    z2 = mult(x1, y1, b, m)
    z0 = mult(x0, y0, b, m)
    z1 = mult(x1 + x0, y1 + y0, b, m) - z2 - z0
    return mult(z2, bm ** 2, b, m) + mult(z1, bm, b, m) + z0
Usually, big numbers are stored as arrays of integers, where each integer represents one digit. This approach makes it possible to multiply any number by a power of the base with a simple left shift of the array.
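As a quick illustration of that shift, here is a minimal sketch. It assumes the digits are stored least significant first, so multiplying by the base to the power n means putting n zeros at the front of the list. The name `shift_left` is made up for this example.

```python
def shift_left(digits, n):
    """Multiply a least-significant-first digit list by base**n.

    No arithmetic is needed: n zero digits are placed at the low end.
    """
    return [0] * n + digits


twelve = [2, 1]                 # 12 in base 10, least significant digit first
print(shift_left(twelve, 3))    # [0, 0, 0, 2, 1], i.e. 12000
```

The cost is proportional to the length of the list, which is why Karatsuba counts shifts as cheap compared with multiplications.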
Here is my list-based implementation (may contain bugs):
import operator

def normalize(l,b):
    # Propagate carries (or borrows) so that every digit ends up in range(b).
    over = 0
    for i,x in enumerate(l):
        over,l[i] = divmod(x+over,b)
    if over: l.append(over)
    return l

def sum_lists(x,y,b):
    l = min(len(x),len(y))
    # list(...) is needed in Python 3, where map returns an iterator.
    res = list(map(operator.add,x[:l],y[:l]))
    if len(x) > l: res.extend(x[l:])
    else: res.extend(y[l:])
    return normalize(res,b)

def sub_lists(x,y,b):
    # Assumes the value of x is not smaller than the value of y.
    res = list(map(operator.sub,x[:len(y)],y))
    res.extend(x[len(y):])
    return normalize(res,b)

def lshift(x,n):
    # Multiply by b**n: prepend n zero digits (zero itself stays unchanged).
    if len(x) > 1 or len(x) == 1 and x[0] != 0:
        return [0 for i in range(n)] + x
    else: return x

def mult_lists(x,y,b):
    if min(len(x),len(y)) == 0: return [0]
    m = max(len(x),len(y))
    if (m == 1): return normalize([x[0]*y[0]],b)
    else: m >>= 1
    # Split into low parts (x0, y0) and high parts (x1, y1).
    x0,x1 = x[:m],x[m:]
    y0,y1 = y[:m],y[m:]
    z0 = mult_lists(x0,y0,b)
    z1 = mult_lists(x1,y1,b)
    z2 = mult_lists(sum_lists(x0,x1,b),sum_lists(y0,y1,b),b)
    t1 = lshift(sub_lists(z2,sum_lists(z1,z0,b),b),m)
    t2 = lshift(z1,m*2)
    return sum_lists(sum_lists(z0,t1,b),t2,b)
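The element-wise addition and subtraction inside sum_lists and sub_lists produce an unnormalized result, in which a single digit can be greater than or equal to the base (or negative). The normalize function solves this problem by propagating the carries.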
All functions expect a list of digits in reverse order. For example, 12 in base 10 should be written as [2,1]. Let's take the square of 987654321.
>>> a = [1,2,3,4,5,6,7,8,9]
>>> res = mult_lists(a,a,10)
>>> res.reverse()
>>> res
[9, 7, 5, 4, 6, 1, 0, 5, 7, 7, 8, 9, 9, 7, 1, 0, 4, 1]
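If you want to start from ordinary integers, the reversed digit lists can be built and read back with a pair of small helpers. This is just a sketch: `to_digits` and `from_digits` are made-up names and are not part of the implementation above.

```python
def to_digits(n, base=10):
    """Return the digits of a non-negative integer, least significant first."""
    if n == 0:
        return [0]
    digits = []
    while n:
        n, d = divmod(n, base)  # peel off the lowest digit
        digits.append(d)
    return digits


def from_digits(digits, base=10):
    """Inverse of to_digits: rebuild the integer from a reversed digit list."""
    value = 0
    for d in reversed(digits):  # start from the most significant digit
        value = value * base + d
    return value


print(to_digits(12))                                  # [2, 1]
print(from_digits([1, 2, 3, 4, 5, 6, 7, 8, 9]))       # 987654321
```

This also makes it easy to check a result against Python's built-in multiplication by converting the returned list back with `from_digits`.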