I'm trying to implement divide-and-conquer matrix multiplication (the 8-recursion version, not Strassen). I thought I had it figured out, but it produces weird output with too many nested lists and the wrong values. I suspect the problem is how I'm summing the 8 recursive results, but I'm not sure.
```python
def multiMatrix(x, y):
    n = len(x)
    if n == 1:
        return x[0][0] * y[0][0]
    else:
        # Quadrants of x: a (top-left), b (top-right), c (bottom-left), d (bottom-right)
        a = [[col for col in row[:len(row)//2]] for row in x[:len(x)//2]]
        b = [[col for col in row[len(row)//2:]] for row in x[:len(x)//2]]
        c = [[col for col in row[:len(row)//2]] for row in x[len(x)//2:]]
        d = [[col for col in row[len(row)//2:]] for row in x[len(x)//2:]]
        # Quadrants of y: e, f, g, h in the same order
        e = [[col for col in row[:len(row)//2]] for row in y[:len(y)//2]]
        f = [[col for col in row[len(row)//2:]] for row in y[:len(y)//2]]
        g = [[col for col in row[:len(row)//2]] for row in y[len(y)//2:]]
        h = [[col for col in row[len(row)//2:]] for row in y[len(y)//2:]]
        ae = multiMatrix(a, e)
        bg = multiMatrix(b, g)
        af = multiMatrix(a, f)
        bh = multiMatrix(b, h)
        ce = multiMatrix(c, e)
        dg = multiMatrix(d, g)
        cf = multiMatrix(c, f)
        dh = multiMatrix(d, h)
        # Combine the eight products (this is where I suspect the problem is)
        c = [[ae+bg, af+bh], [ce+dg, cf+dh]]
        return c

a = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
]
b = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
    [13, 14, 15, 16]
]
print(multiMatrix(a, b))
```
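For reference, the block identity that the eight recursive calls are meant to implement is shown below, with A through H standing for the quadrants a through h in the code. Each + here is element-wise addition of two n/2 x n/2 blocks, and each block product is itself computed recursively:

$$
\begin{pmatrix} A & B \\ C & D \end{pmatrix}
\begin{pmatrix} E & F \\ G & H \end{pmatrix}
=
\begin{pmatrix} AE + BG & AF + BH \\ CE + DG & CF + DH \end{pmatrix}
$$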
Your suspicion is correct: your matrices are still Python lists, so adding them with `+` just concatenates them into a longer list instead of adding them element by element.
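To see the difference concretely, here is a minimal illustration using two made-up 2x2 matrices (the names `p` and `q` are only for this example):

```python
# Two 2x2 matrices stored as lists of rows (illustrative values)
p = [[1, 2], [3, 4]]
q = [[5, 6], [7, 8]]

# + on lists concatenates: the result has 4 rows instead of 2
concatenated = p + q
print(concatenated)  # [[1, 2], [3, 4], [5, 6], [7, 8]]

# Element-wise addition has to be spelled out row by row
elementwise = [[u + v for u, v in zip(row_p, row_q)] for row_p, row_q in zip(p, q)]
print(elementwise)  # [[6, 8], [10, 12]]
```

In `multiMatrix`, this concatenation happens at every level above the 2x2 case, which is where the extra nesting comes from.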
Try using something like this for element-wise addition in your code:

```python
def matrix_add(a, b):
    return [[ea+eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]
```
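A quick check of `matrix_add` on two made-up 2x2 blocks (the block values are illustrative only):

```python
def matrix_add(a, b):
    # Pair up corresponding rows, then add corresponding entries within each pair
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

ae = [[1, 2], [3, 4]]
bg = [[10, 20], [30, 40]]
print(matrix_add(ae, bg))  # [[11, 22], [33, 44]]
```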
To join the resulting blocks back into one matrix:

```python
def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    return a + b
```
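As a sketch, here is how the two helpers would glue four 2x2 quadrants back into the 4x4 matrix `a` from the question (the quadrant variable names are my own):

```python
def join_horiz(a, b):
    # Row i of the result is row i of a followed by row i of b
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    # Stack b's rows underneath a's rows
    return a + b

top_left = [[1, 2], [5, 6]]
top_right = [[3, 4], [7, 8]]
bottom_left = [[9, 10], [13, 14]]
bottom_right = [[11, 12], [15, 16]]

result = join_vert(join_horiz(top_left, top_right),
                   join_horiz(bottom_left, bottom_right))
print(result)
# [[1, 2, 3, 4], [5, 6, 7, 8], [9, 10, 11, 12], [13, 14, 15, 16]]
```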
Finally, to make it all work together, I think you have to change your special case for n == 1 so that it returns a 1x1 matrix instead of a plain number:

```python
return [[x[0][0] * y[0][0]]]
```
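Putting these pieces together, a minimal sketch of the corrected function might look like the following. It assumes square inputs whose size is a power of two, and `split` is a hypothetical helper I added to replace the eight slicing lines; `matrix_add`, `join_horiz` and `join_vert` are the helpers above:

```python
def matrix_add(a, b):
    return [[ea + eb for ea, eb in zip(*rowpair)] for rowpair in zip(a, b)]

def join_horiz(a, b):
    return [rowa + rowb for rowa, rowb in zip(a, b)]

def join_vert(a, b):
    return a + b

def split(m):
    # Hypothetical helper: the four quadrants of a square matrix with even size
    half = len(m) // 2
    top, bottom = m[:half], m[half:]
    return ([row[:half] for row in top], [row[half:] for row in top],
            [row[:half] for row in bottom], [row[half:] for row in bottom])

def multiMatrix(x, y):
    # Assumes x and y are n x n with n a power of two
    n = len(x)
    if n == 1:
        # Base case returns a 1x1 matrix, not a bare number
        return [[x[0][0] * y[0][0]]]
    a, b, c, d = split(x)
    e, f, g, h = split(y)
    # Eight recursive products, combined with element-wise addition
    top_left = matrix_add(multiMatrix(a, e), multiMatrix(b, g))
    top_right = matrix_add(multiMatrix(a, f), multiMatrix(b, h))
    bottom_left = matrix_add(multiMatrix(c, e), multiMatrix(d, g))
    bottom_right = matrix_add(multiMatrix(c, f), multiMatrix(d, h))
    return join_vert(join_horiz(top_left, top_right),
                     join_horiz(bottom_left, bottom_right))

m = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]
print(multiMatrix(m, m))
# [[90, 100, 110, 120], [202, 228, 254, 280], [314, 356, 398, 440], [426, 484, 542, 600]]
```

Returning a 1x1 list at the bottom means every recursive result has the same shape, so `matrix_add` and the join helpers work unchanged at every level.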
Edit:
I just realised that this will only work for power-of-two dimensions. Otherwise you will have to deal with non-square blocks, and at some point x will be 1 x something, so your special case won't work. You'll therefore also have to check len(x[0]) (if n > 0).
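As an alternative to the extra `len(x[0])` check described in the edit (this is a different technique, not what the answer proposes), you could zero-pad both inputs up to the next power of two, run the power-of-two recursion, and crop the result. A minimal sketch, where `quadrant`, `multi_pow2`, `pad` and `multiply_any` are hypothetical names and `x` is assumed to be m x k with `y` k x p:

```python
def matrix_add(a, b):
    # Element-wise sum of two equally sized matrices
    return [[ea + eb for ea, eb in zip(ra, rb)] for ra, rb in zip(a, b)]

def quadrant(m, r, c, size):
    # The size x size block of m whose top-left corner is at (r, c)
    return [row[c:c + size] for row in m[r:r + size]]

def multi_pow2(x, y):
    # Eight-recursion block multiply; assumes n x n inputs with n a power of two
    n = len(x)
    if n == 1:
        return [[x[0][0] * y[0][0]]]
    h = n // 2
    corners = ((0, 0), (0, h), (h, 0), (h, h))
    a, b, c, d = (quadrant(x, r, col, h) for r, col in corners)
    e, f, g, k = (quadrant(y, r, col, h) for r, col in corners)
    top_left = matrix_add(multi_pow2(a, e), multi_pow2(b, g))
    top_right = matrix_add(multi_pow2(a, f), multi_pow2(b, k))
    bottom_left = matrix_add(multi_pow2(c, e), multi_pow2(d, g))
    bottom_right = matrix_add(multi_pow2(c, f), multi_pow2(d, k))
    # Glue the quadrants side by side, then stack the two halves
    top = [left + right for left, right in zip(top_left, top_right)]
    bottom = [left + right for left, right in zip(bottom_left, bottom_right)]
    return top + bottom

def pad(m, size):
    # Copy m into the top-left corner of a size x size matrix of zeros
    rows = [list(row) + [0] * (size - len(row)) for row in m]
    return rows + [[0] * size for _ in range(size - len(m))]

def multiply_any(x, y):
    # x is m x k and y is k x p; all dimensions assumed to be at least 1
    m, k, p = len(x), len(y), len(y[0])
    size = 1
    while size < max(m, k, p):
        size *= 2
    full = multi_pow2(pad(x, size), pad(y, size))
    # The zero padding contributes nothing, so the m x p corner equals x times y
    return [row[:p] for row in full[:m]]

print(multiply_any([[1, 2, 3], [4, 5, 6]], [[7, 8], [9, 10], [11, 12]]))
# [[58, 64], [139, 154]]
```

Padding trades some wasted work on zeros for a simpler recursion that never sees a non-square block.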