
How to correctly implement Fermat's factorization in Python?


I am trying to implement efficient prime factorization algorithms in Python. This is not homework or work-related; it is purely out of curiosity.

I have learned that prime factorization is very hard.

I want to implement efficient algorithms for this as a self-imposed challenge. I have decided to implement Fermat's factorization method first, as it seems simple enough.

Python code directly translated from the pseudocode:

def Fermat_Factor(n):
    a = int(n ** 0.5 + 0.5)
    b2 = abs(a**2 - n)
    while int(b2**0.5) ** 2 != b2:
        a += 1
        b2 = a**2 - n

    return a - b2**0.5, a + b2**0.5

(I have to use abs; otherwise b2 easily becomes negative, and the int cast fails with a TypeError because the square root is complex.)

As you can see, it returns two numbers whose product equals the input, but it only ever returns two of them and doesn't guarantee that they are prime. I have no idea how efficient this algorithm is, but factoring semiprimes this way is much faster than the trial division method used in my previous question: Why factorization of products of close primes is much slower than products of dissimilar primes.

In [20]: %timeit FermatFactor(3607*3803)
2.1 μs ± 28.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [21]: FermatFactor(3607*3803)
Out[21]: [3607, 3803]

In [22]: %timeit FermatFactor(3593 * 3671)
1.69 μs ± 31 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [23]: FermatFactor(3593 * 3671)
Out[23]: [3593, 3671]

In [24]: %timeit FermatFactor(7187 * 7829)
4.94 μs ± 47.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)

In [25]: FermatFactor(7187 * 7829)
Out[25]: [7187, 7829]

In [26]: %timeit FermatFactor(8087 * 8089)
1.38 μs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)

In [27]: FermatFactor(8087 * 8089)
Out[27]: [8087, 8089]

So I want to use this algorithm to generate all prime factors of any given integer (of course, I know this only works with odd integers, but that is not an issue, since powers of two can be trivially factored out using bit hacking). The easiest way I can think of is to call Fermat_Factor recursively until n is prime. I don't know how to check whether a number is prime within this algorithm, but I noticed something:
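The "bit hacking" step mentioned above could look like the following minimal sketch; split_powers_of_two is a hypothetical helper name. It relies on n & -n isolating the lowest set bit of a positive Python int, which equals 2**k where k is the number of trailing zero bits.

```python
def split_powers_of_two(n):
    """Hypothetical helper: return (k, m) such that n == 2**k * m and m is odd."""
    if n <= 0:
        raise ValueError("n must be a positive integer")
    lowest_bit = n & -n              # e.g. 40 = 0b101000 -> 0b1000 == 8
    k = lowest_bit.bit_length() - 1  # 8 has bit_length 4, so k == 3
    return k, n >> k                 # shift the factors of two away


print(split_powers_of_two(40))    # (3, 5)
print(split_powers_of_two(1024))  # (10, 1)
```

Only the odd part m then needs to go through Fermat's method.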

In [3]: Fermat_Factor(3)
Out[3]: (1.0, 3.0)

In [4]: Fermat_Factor(5)
Out[4]: (1.0, 3.0)

In [5]: Fermat_Factor(7)
Out[5]: (1.0, 7.0)

In [6]: Fermat_Factor(11)
Out[6]: (1.0, 11.0)

In [7]: Fermat_Factor(13)
Out[7]: (1.0, 13.0)

In [8]: Fermat_Factor(17)
Out[8]: (3.0, 5.0)

In [9]: Fermat_Factor(19)
Out[9]: (1.0, 19.0)

In [10]: Fermat_Factor(23)
Out[10]: (1.0, 23.0)

In [11]: Fermat_Factor(29)
Out[11]: (3.0, 7.0)

In [12]: Fermat_Factor(31)
Out[12]: (1.0, 31.0)

In [13]: Fermat_Factor(37)
Out[13]: (5.0, 7.0)

In [14]: Fermat_Factor(41)
Out[14]: (1.0, 41.0)

For many primes the first number in the output is 1, but not for all of them, so it cannot be used to decide when the recursion should stop. I learned that the hard way.
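Checking each returned pair against its input suggests that the primes where the first number is not 1 are not factorizations at all. The sketch below reruns the Fermat_Factor function from above unchanged and keeps only the pairs whose product differs from the input:

```python
def Fermat_Factor(n):
    # Copied unchanged from the question: round-to-nearest start plus abs()
    a = int(n ** 0.5 + 0.5)
    b2 = abs(a**2 - n)
    while int(b2**0.5) ** 2 != b2:
        a += 1
        b2 = a**2 - n
    return a - b2**0.5, a + b2**0.5


small_primes = [3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41]
mismatches = {}
for p in small_primes:
    x, y = Fermat_Factor(p)
    if x * y != p:              # the pair does not multiply back to p
        mismatches[p] = (x, y)

print(mismatches)  # {5: (1.0, 3.0), 17: (3.0, 5.0), 29: (3.0, 7.0), 37: (5.0, 7.0)}
```

Each mismatch is a sum of two squares (5 = 2² + 1², 17 = 4² + 1², 29 = 5² + 2², 37 = 6² + 1²). Because int(n ** 0.5 + 0.5) can start below sqrt(n), abs turns a² - n into n - a², and the loop stops at a² + b² = n instead of a² - b² = n.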

So I settled on checking membership in a pregenerated set of primes instead. Naturally, this causes RecursionError: maximum recursion depth exceeded when the input is a prime larger than the largest prime in the set. Since I don't have infinite memory, I consider this an implementation limitation.

So I have implemented a version that works for some inputs, but for other valid inputs (products of primes within the limit) it somehow gives the wrong output:

import numpy as np
from itertools import cycle

TRIPLE = ((4, 2), (9, 6), (25, 10))
WHEEL = ( 4, 2, 4, 2, 4, 6, 2, 6 )
def prime_sieve(n):
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    for square, double in TRIPLE:
        primes[square::double] = False
    
    wheel = cycle(WHEEL)
    k = 7
    while (square := k**2) <= n:
        if primes[k]:
            primes[square::2*k] = False
        
        k += next(wheel)
    
    return np.flatnonzero(primes)
    
PRIMES = list(map(int, prime_sieve(1048576)))
PRIME_SET = set(PRIMES)
TEST_LIMIT = PRIMES[-1] ** 2

def FermatFactor(n):
    if n > TEST_LIMIT:
        raise ValueError('Number too large')
    
    if n in PRIME_SET:
        return [n]
    
    a = int(n ** 0.5 + 0.5)
    if a ** 2 == n:
        return FermatFactor(a) + FermatFactor(a)
    
    b2 = abs(a**2 - n)
    while int(b2**0.5) ** 2 != b2:
        a += 1
        b2 = a**2 - n
    
    return FermatFactor(factor := int(a - b2**0.5)) + FermatFactor(n // factor)
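Separately from the bug asked about here, the perfect-square test int(b2**0.5) ** 2 != b2 goes through a float, and a float carries only 53 bits of mantissa. For a split n = c * d the loop ends with b2 = ((d - c) / 2)**2, which for inputs near TEST_LIMIT with far-apart factors is well beyond 2**53, so the float test can misjudge a genuine square. A minimal comparison, assuming Python 3.8+ for math.isqrt (both helper names are hypothetical):

```python
from math import isqrt


def is_square_float(m):
    # The question's test: the float square root loses precision for large m
    return int(m ** 0.5) ** 2 == m


def is_square_exact(m):
    # Exact integer square root, no rounding involved
    return isqrt(m) ** 2 == m


m = (2**60 + 1) ** 2       # a genuine perfect square, far above 2**53
print(is_square_float(m))  # False: m is rounded to 2**120 before the sqrt
print(is_square_exact(m))  # True
```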

It works for many inputs:

In [18]: FermatFactor(255)
Out[18]: [3, 5, 17]

In [19]: FermatFactor(511)
Out[19]: [7, 73]

In [20]: FermatFactor(441)
Out[20]: [3, 7, 3, 7]

In [21]: FermatFactor(3*5*823)
Out[21]: [3, 5, 823]

In [22]: FermatFactor(37*333667)
Out[22]: [37, 333667]

In [23]: FermatFactor(13 * 37 * 151 * 727 * 3607)
Out[23]: [13, 37, 727, 151, 3607]

But not all:

In [25]: FermatFactor(5 * 53 * 163)
Out[25]: [163, 13, 2, 2, 5]

In [26]: FermatFactor(3*5*73*283)
Out[26]: [17, 3, 7, 3, 283]

In [27]: FermatFactor(3 * 11 * 29 * 71 *  137)
Out[27]: [3, 11, 71, 61, 7, 3, 3]
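Replaying the question's split step on the first failing input hints at what goes wrong. For 5 * 53 * 163 the first split is correct (163 and 265), but splitting 265 takes the rounding-plus-abs path, and n // factor then silently truncates because 13 does not divide 265:

```python
n = 265                    # 5 * 53, the cofactor left after 163 is split off
a = int(n ** 0.5 + 0.5)    # rounds 16.28 to 16, which is BELOW sqrt(265)
b2 = abs(a**2 - n)         # abs(256 - 265) = 9 is already a square, so no loop
factor = int(a - b2**0.5)  # 16 - 3 = 13, yet 265 == 16**2 + 3**2, not 16**2 - 3**2
print(factor, n % factor, n // factor)  # 13 5 20
```

The truncated quotient 20 is then factored further into the spurious 2, 2, 5, which matches the output [163, 13, 2, 2, 5] above.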

Why is this the case? How can I fix it?


Solution

  • You're supposed to start with a ← ceiling(sqrt(N)), not a = int(n ** 0.5 + 0.5). At the very least use a = math.ceil(n ** 0.5) instead; then Fermat_Factor(17) already gives (1.0, 17.0) instead of (3.0, 5.0). But it's really better to stay away from floats altogether and use math.isqrt. And of course you don't need abs if you actually compute the ceiling.

    from math import isqrt
    
    def Fermat_Factor(n):
        a = 1 + isqrt(n - 1)
        b2 = a**2 - n
        while isqrt(b2) ** 2 != b2:
            a += 1
            b2 = a**2 - n
        return a - isqrt(b2), a + isqrt(b2)
    

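Combining the isqrt-based start with the question's idea of splitting repeatedly, a full factorization might look like the sketch below. It assumes Python 3.8+ for math.isqrt; fermat_split and fermat_factorize are hypothetical names, and an explicit stack replaces the recursion. Because the search now starts at the true ceiling and walks upward, the first hit is the factor pair closest to sqrt(n), so for odd n the first number is 1 exactly when n is prime. That could serve as the stopping criterion instead of a pregenerated prime set.

```python
from math import isqrt


def fermat_split(n):
    """Split odd n >= 3 into (c, d) with c * d == n and c <= d.

    The search starts at a = ceil(sqrt(n)) and walks upward, so the first hit
    is the factor pair closest to sqrt(n). Hence c == 1 only when n has no
    divisor in (1, sqrt(n)], i.e. when n is prime.
    """
    a = 1 + isqrt(n - 1)  # ceil(sqrt(n)) computed without floats
    b2 = a * a - n        # never negative, so no abs() is needed
    while isqrt(b2) ** 2 != b2:
        a += 1
        b2 = a * a - n
    b = isqrt(b2)
    return a - b, a + b


def fermat_factorize(n):
    """Return the prime factors of an integer n >= 1 in ascending order."""
    if n < 1:
        raise ValueError("n must be a positive integer")
    factors = []
    while n % 2 == 0:  # Fermat's method needs an odd input
        factors.append(2)
        n //= 2
    pending = [n] if n > 1 else []  # explicit stack instead of recursion
    while pending:
        m = pending.pop()
        c, d = fermat_split(m)
        if c == 1:  # only the trivial split exists: m is prime
            factors.append(m)
        else:       # both parts are odd and > 1; split them further
            pending.extend((c, d))
    return sorted(factors)


print(fermat_factorize(5 * 53 * 163))            # [5, 53, 163]
print(fermat_factorize(3 * 11 * 29 * 71 * 137))  # [3, 11, 29, 71, 137]
```

For a prime input the loop in fermat_split runs roughly n/2 times, so this stays practical only for modest inputs; the point of the sketch is the stopping criterion, not speed.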