python number-theory exponentiation primality-test

Why is my Fermat primality test so slow?


I am learning number theory and want to write a program that performs the Fermat primality test.

First, I wrote a modular exponentiation (square-and-multiply) function:

#modular_square.py
def modular_square(a, n, m):
    res = 1
    exp = n
    b = a 

    while exp !=0 :

        if exp % 2 == 1:
            res *= b
            res %= m

        b *= b
        exp >>= 1

    return res


def main():
    a = [   12996,      312,        501,        468,        163]
    n = [   227,        13,         13,         237,        237]
    m = [   37909,      667,        667,        667,        667]
    res = [ 7775,       468,        163,        312,        501]
    #test modular_square()
    print("===test modular_square()===")
    for i, r in enumerate(res):
        if modular_square(a[i], n[i], m[i]) != r:
            print("modular_square() failed...")
        else:
            print("modular_square({},{},{})={}".format(a[i], n[i], m[i], r))  


if __name__ == "__main__":
    main()

Then I wrote the Fermat primality test on top of that function:

#prime_test_fermat.py

import modular_square
import random

def Fermat_base(b, n):
    res = modular_square.modular_square(b, n-1, n)

    if res == 1:
        return True
    else:
        return False


def Fermat_test(n, times):

    for i in range(times):
        b = random.randint(2, n-1)
        if Fermat_base(b, n) == False:
            return False

    return True

def main():
    b = [8,         2]
    n = [63,        63]
    res = [True,    False]
    #test Fermat_base()
    print("===test Fermat_base()===")
    for i,r in enumerate(res):
        if Fermat_base(b[i], n[i]) != res[i]:
            print("Fermat_base() failed...")
        else:
            print("Fermat_base({},{})={}".format(b[i], n[i], res[i]))

    n = [923861,
        1056420454404911
         ]

    times = [2, 2]

    res = [True,True ]
    #test Fermat_test()
    print("==test Fermat_test()===")
    for i,r in enumerate(res):
        if Fermat_test(n[i], times[i]) != res[i]:
            print("Fermat_test() failed...")
        else:
            print("Fermat_test({},{})={}".format(n[i], times[i], res[i]))

if __name__ == '__main__':
    main()

When I run prime_test_fermat.py, it never finishes. Is this inherent to the Fermat primality test, or is there a bug in my code?


Solution

  • The problem is in your modular exponentiation function: the modulus is applied to res but never to b. Because b is squared on every iteration, its number of digits roughly doubles each time, so for a large exponent it quickly grows to thousands of digits and beyond. Multiplying such huge integers is what makes the program appear to hang.

    To solve this, you have to apply the modulus to b as well. Replace b *= b with:

    b *= b
    b %= m
    

    As an additional optimization, you can also apply the modulus when you initialize b, by replacing b = a with:

    b = a
    b %= m
    

    You can take this pseudo-code as reference.
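
For completeness, here is a minimal sketch of the whole function with both changes applied, reusing the names from the question (modular_square, Fermat_base). The `1 % m` initialisation and the comparison against Python's built-in three-argument pow are my own additions for illustration, not part of the original code:

```python
def modular_square(a, n, m):
    """Compute (a ** n) % m by square-and-multiply, reducing at every step."""
    res = 1 % m          # addition: gives 0 when m == 1
    b = a % m            # reduce the base once up front
    exp = n

    while exp != 0:
        if exp % 2 == 1:
            res = (res * b) % m
        b = (b * b) % m  # the fix: keep b below m on every iteration
        exp >>= 1

    return res


def Fermat_base(b, n):
    """Return True if b**(n-1) is congruent to 1 modulo n."""
    return modular_square(b, n - 1, n) == 1


if __name__ == "__main__":
    n = 1056420454404911  # the large value from the question
    # Cross-check against the built-in pow(base, exp, mod).
    print(modular_square(2, n - 1, n) == pow(2, n - 1, n))
    print(Fermat_base(8, 63), Fermat_base(2, 63))
```

With b kept below m, every multiplication involves numbers smaller than m squared, so the loop runs in time proportional to the number of bits in the exponent. In practice you could also skip the hand-written loop and call pow(b, n - 1, n) directly, which performs the same modular exponentiation.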