pythonpython-3.xalgorithmrecursionstrassen

Strassen matrix multiplication in python


def matrix_addition(A, B):
    # Check if matrices have the same size
    if len(A) != len(B) or len(A[0]) != len(B[0]):
        raise ValueError("Matrices must have the same size")

    # Initialize result matrix with zeros
    result = [[0 for col in range(len(A[0]))] for row in range(len(A))]

    # Add matrices element-wise
    for row in range(len(A)):
        for col in range(len(A[0])):
            result[row][col] = A[row][col] + B[row][col]

    return result


def matrix_subtraction(A, B):
    # Check if matrices have the same size
    if len(A) != len(B) or len(A[0]) != len(B[0]):
        raise ValueError("Matrices must have the same size")

    # Initialize result matrix with zeros
    result = [[0 for col in range(len(A[0]))] for row in range(len(A))]

    # Subtract matrices element-wise
    for row in range(len(A)):
        for col in range(len(A[0])):
            result[row][col] = A[row][col] - B[row][col]

    return result
def strassen(a, b):
    if len(A) == 1 and len(A[0]) == 1:
        return a[0][0] * b[0][0]
    else:
        # divide into quadrants
        quad1_a, quad2_a, quad3_a, quad4_a = divide(a)
        quad1_b, quad2_b, quad3_b, quad4_b = divide(b)
        #break into parts to compute 
        p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
        p2 = strassen(matrix_addition(quad3_a + quad4_a), quad1_b)
        p3 = strassen(quad1_a, matrix_subtraction(quad3_b, quad1_b))
        p4 = strassen(quad4_a, matrix_subtraction(quad3_b, quad1_b))
        p5 = strassen(matrix_addition(quad1_a, quad2_a), quad4_b)
        p6 = strassen(matrix_subtraction(quad3_a, quad1_a), matrix_addition(quad1_b, quad2_b))
        p7 = strassen(matrix_subtraction(quad2_a, quad4_a), matrix_addition(quad3_b, quad4_b))
        # create the final matrix
        final_quad1 = matrix_subtraction(matrix_addition(p1, p4), matrix_addition(p5, p7))
        final_quad2 = matrix_addition(p3, p5)
        final_quad3 = matrix_addition(p2, p4)
        final_quad4 = matrix_addition(matrix_subtraction(p1, p2), matrix_addition(p3, p6))
        resultant_matrix = combine_submatrices(final_quad1, final_quad2, final_quad3, final_quad4)
        return resultant_matrix

its a basic implementation of the strassen algorithm i have tested all the secondary function and they work but joined together i keep running into problems.

the strassen function is supposed to take 2 2d arrays of 2^n size for the above code i used the arrays

A = [[1, 2, 3, 4],
     [5, 6, 7, 8],
     [9, 10, 11, 12],
     [13, 14, 15, 16]]

B = [[17, 18, 19, 20],
     [21, 22, 23, 24],
     [25, 26, 27, 28],
     [29, 30, 31, 32]]

the result should be this

C = [[250, 260, 270, 280],
     [618, 644, 670, 696],
     [986, 1028, 1070, 1112],
     [1354, 1412, 1470, 1528]]

i have ran the code multiple times and i get into the problem

    raise ValueError("Matrices must have the same size")
ValueError: Matrices must have the same size

Process finished with exit code 1

if i turn the exception code off i run into a different problem

Traceback (most recent call last):
  File "strassen_matrix.py", line 112, in <module>
    print(strassen(A, B))
          ^^^^^^^^^^^^^^
  File "strassen_matrix.py", line 97, in strassen
    p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "strassen_matrix.py", line 97, in strassen
    p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "strassen_matrix.py", line 97, in strassen
    p1 = strassen(matrix_addition(quad1_a, quad4_a), matrix_addition(quad1_b, quad4_b))
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  [Previous line repeated 994 more times]
  File "strassen_matrix.py", line 95, in strassen
    quad1_a, quad2_a, quad3_a, quad4_a = divide(a)
                                         ^^^^^^^^^
  File "strassen_matrix.py", line 20, in divide
    for x in range(int(len(matrix) / 2)):
                   ^^^^^^^^^^^^^^^^^^^^
RecursionError: maximum recursion depth exceeded while calling a Python objectquer

i tried increasing the recursion limit aswell but still same issue im stumped as to how to fix this any help is appreciated


Solution

  • You have a few errors:

    After fixing those issues, it will work.