I am trying to implement efficient prime factorization algorithms in Python. This is not homework or work-related; it is purely out of curiosity.
I have learned that prime factorization is very hard:
I want to implement efficient algorithms for it as a self-imposed challenge. I set out to implement Fermat's factorization method first, as it seems simple enough.
Python code directly translated from the pseudocode:
```python
def Fermat_Factor(n):
    a = int(n ** 0.5 + 0.5)
    b2 = abs(a**2 - n)
    while int(b2**0.5) ** 2 != b2:
        a += 1
        b2 = a**2 - n
    return a - b2**0.5, a + b2**0.5
```
(I have to use `abs`, otherwise `b2` can easily be negative, and the `int` cast then fails with `TypeError` because the square root is complex.)
As you can see, it returns two numbers whose product equals the input, but it only ever returns two of them, and it doesn't guarantee that they are prime. I have no idea how efficient this algorithm is, but factoring semiprimes with it is much faster than the trial division method used in my previous question, "Why factorization of products of close primes is much slower than products of dissimilar primes":
```
In [20]: %timeit FermatFactor(3607*3803)
2.1 μs ± 28.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [21]: FermatFactor(3607*3803)
Out[21]: [3607, 3803]
In [22]: %timeit FermatFactor(3593 * 3671)
1.69 μs ± 31 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [23]: FermatFactor(3593 * 3671)
Out[23]: [3593, 3671]
In [24]: %timeit FermatFactor(7187 * 7829)
4.94 μs ± 47.4 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
In [25]: FermatFactor(7187 * 7829)
Out[25]: [7187, 7829]
In [26]: %timeit FermatFactor(8087 * 8089)
1.38 μs ± 12.9 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)
In [27]: FermatFactor(8087 * 8089)
Out[27]: [8087, 8089]
```
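Why these products factor so quickly: for n = p·q with odd primes p < q, the only ways to write n = a² − b² with b ≥ 0 come from the factor pairs 1·n and p·q, that is a = (1 + n)/2 or a = (p + q)/2. Fermat's loop therefore stops at a = (p + q)/2, and its cost is the gap between that value and ⌈√n⌉, which is small when p and q are close. The following is a minimal sketch that counts those increments. It assumes the loop starts at the exact ceiling, computed with `math.isqrt`. `fermat_increments` is a hypothetical helper, not part of the code above.

```python
from math import isqrt

def fermat_increments(n: int) -> int:
    """Count how often Fermat's loop increments a, starting from
    a = ceil(sqrt(n)), before a*a - n is a perfect square (odd n > 0)."""
    a = 1 + isqrt(n - 1)  # exact ceil(sqrt(n)) using integer arithmetic only
    steps = 0
    while isqrt(a * a - n) ** 2 != a * a - n:
        a += 1
        steps += 1
    return steps

for p, q in [(3607, 3803), (3593, 3671), (7187, 7829), (8087, 8089)]:
    # for a product of two odd primes the loop should stop at a = (p + q) // 2
    print(f"{p} * {q}: {fermat_increments(p * q)} increments")
```

The counts come out as 1, 0, 6 and 0. The largest count belongs to 7187 · 7829, which is also the slowest of the four timings above.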
So I want to use this algorithm to generate all prime factors of any given integer (of course I know this only works with odd integers, but that is not an issue since powers of two can be trivially factored out using bit hacking). The easiest way I can think of is to call `Fermat_Factor` recursively until `n` is prime. I don't know how to check whether a number is prime within this algorithm, but I noticed something:
```
In [3]: Fermat_Factor(3)
Out[3]: (1.0, 3.0)
In [4]: Fermat_Factor(5)
Out[4]: (1.0, 3.0)
In [5]: Fermat_Factor(7)
Out[5]: (1.0, 7.0)
In [6]: Fermat_Factor(11)
Out[6]: (1.0, 11.0)
In [7]: Fermat_Factor(13)
Out[7]: (1.0, 13.0)
In [8]: Fermat_Factor(17)
Out[8]: (3.0, 5.0)
In [9]: Fermat_Factor(19)
Out[9]: (1.0, 19.0)
In [10]: Fermat_Factor(23)
Out[10]: (1.0, 23.0)
In [11]: Fermat_Factor(29)
Out[11]: (3.0, 7.0)
In [12]: Fermat_Factor(31)
Out[12]: (1.0, 31.0)
In [13]: Fermat_Factor(37)
Out[13]: (5.0, 7.0)
In [14]: Fermat_Factor(41)
Out[14]: (1.0, 41.0)
```
For many primes the first number in the output is 1, but not for all of them, so it cannot be used to decide when the recursion should stop. I learned this the hard way.
So I settled on checking membership in a pregenerated set of primes instead. Naturally, this causes `RecursionError: maximum recursion depth exceeded` when the input is a prime larger than the largest member of the set. Since I don't have infinite memory, this should be considered an implementation detail.
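A pregenerated set is not the only way to stop the recursion. A deterministic Miller-Rabin test needs no stored table. Below is a minimal sketch that assumes inputs below about 3.3 × 10^24. Under that bound, testing the first twelve prime bases (2 through 37) is known to be sufficient. `is_prime` is a hypothetical helper name.

```python
def is_prime(n: int) -> bool:
    """Miller-Rabin with the first twelve prime bases.

    Deterministic for n < 3_317_044_064_679_887_385_961_981 (about 3.3e24);
    for larger n it is only a strong probable-prime test, not a proof.
    """
    bases = (2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37)
    if n < 2:
        return False
    for p in bases:  # handle the small primes and their multiples directly
        if n % p == 0:
            return n == p
    d, s = n - 1, 0  # write n - 1 = d * 2**s with d odd
    while d % 2 == 0:
        d //= 2
        s += 1
    for a in bases:
        x = pow(a, d, n)  # built-in modular exponentiation
        if x in (1, n - 1):
            continue
        for _ in range(s - 1):
            x = x * x % n
            if x == n - 1:
                break
        else:
            return False  # a witnesses that n is composite
    return True
```

For example, `is_prime(333667)` is `True` and `is_prime(37 * 333667)` is `False`. The test costs a handful of modular exponentiations regardless of how large the prime is.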
So I have implemented a version that works for some inputs, but for other valid inputs (products of primes within the limit) the algorithm somehow gives the wrong output:
```python
import numpy as np
from itertools import cycle

TRIPLE = ((4, 2), (9, 6), (25, 10))
WHEEL = (4, 2, 4, 2, 4, 6, 2, 6)

def prime_sieve(n):
    primes = np.ones(n + 1, dtype=bool)
    primes[:2] = False
    for square, double in TRIPLE:
        primes[square::double] = False
    wheel = cycle(WHEEL)
    k = 7
    while (square := k**2) <= n:
        if primes[k]:
            primes[square::2*k] = False
        k += next(wheel)
    return np.flatnonzero(primes)

PRIMES = list(map(int, prime_sieve(1048576)))
PRIME_SET = set(PRIMES)
TEST_LIMIT = PRIMES[-1] ** 2

def FermatFactor(n):
    if n > TEST_LIMIT:
        raise ValueError('Number too large')
    if n in PRIME_SET:
        return [n]
    a = int(n ** 0.5 + 0.5)
    if a ** 2 == n:
        return FermatFactor(a) + FermatFactor(a)
    b2 = abs(a**2 - n)
    while int(b2**0.5) ** 2 != b2:
        a += 1
        b2 = a**2 - n
    return FermatFactor(factor := int(a - b2**0.5)) + FermatFactor(n // factor)
```
It works for many inputs:
```
In [18]: FermatFactor(255)
Out[18]: [3, 5, 17]
In [19]: FermatFactor(511)
Out[19]: [7, 73]
In [20]: FermatFactor(441)
Out[20]: [3, 7, 3, 7]
In [21]: FermatFactor(3*5*823)
Out[21]: [3, 5, 823]
In [22]: FermatFactor(37*333667)
Out[22]: [37, 333667]
In [23]: FermatFactor(13 * 37 * 151 * 727 * 3607)
Out[23]: [13, 37, 727, 151, 3607]
```
But not all:
```
In [25]: FermatFactor(5 * 53 * 163)
Out[25]: [163, 13, 2, 2, 5]
In [26]: FermatFactor(3*5*73*283)
Out[26]: [17, 3, 7, 3, 283]
In [27]: FermatFactor(3 * 11 * 29 * 71 * 137)
Out[27]: [3, 11, 71, 61, 7, 3, 3]
```
Why is this the case? How can I fix it?
You're supposed to start with `a ← ceiling(sqrt(N))`, not `a = int(n ** 0.5 + 0.5)`. At the very least, use `a = math.ceil(n ** 0.5)` instead; then `Fermat_Factor(17)` already gives `(1.0, 17.0)` instead of `(3.0, 5.0)`. But it's really better to stay away from floats altogether and use `math.isqrt`. And of course you don't need `abs` if you actually compute the ceiling.
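To see why the rounded start produces wrong factors, note that Fermat's method relies on the identity below. The identity only holds when b² = a² − n is non-negative, that is, when a ≥ √n:

$$
n = a^2 - b^2 = (a - b)(a + b), \qquad b^2 = a^2 - n \ge 0 \iff a \ge \sqrt{n}.
$$

Rounding to the nearest integer can give a < √n. If n − a² then happens to be a perfect square, `abs` accepts b² = n − a², and the function returns a pair whose product is

$$
(a - b)(a + b) = a^2 - b^2 = a^2 - (n - a^2) = 2a^2 - n \ne n.
$$

For n = 17, a = 4 and b² = 1, so the product is 2·16 − 17 = 15 = 3·5. That is exactly the `(3.0, 5.0)` in the question. The same formula gives 3 = 1·3 for n = 5, 21 = 3·7 for n = 29, and 35 = 5·7 for n = 37.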
```python
from math import isqrt

def Fermat_Factor(n):
    a = 1 + isqrt(n - 1)
    b2 = a**2 - n
    while isqrt(b2) ** 2 != b2:
        a += 1
        b2 = a**2 - n
    return a - isqrt(b2), a + isqrt(b2)
```
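Putting the pieces together gives the following minimal sketch. The names `fermat_split` and `prime_factors` are hypothetical. The code strips factors of two with the `n & -n` bit trick, then splits the odd part with the corrected ceiling start. With that start, an odd m > 1 comes back as (1, m) exactly when m is prime. Any nontrivial factor pair has a smaller a = (c + d)/2 than the pair 1·m, so the loop reaches it first. This lets the result itself serve as the stopping test instead of a pregenerated prime set. The catch is speed: confirming a prime p this way takes roughly p/2 loop iterations. For large prime factors, a dedicated primality test (such as Miller-Rabin) is the usual remedy.

```python
from math import isqrt

def fermat_split(n: int) -> tuple[int, int]:
    """Return (c, d) with c * d == n and c <= d, for odd n >= 1.
    Starts at a = ceil(sqrt(n)) as in the answer above, so b2 is never negative."""
    a = 1 + isqrt(n - 1)
    b2 = a * a - n
    while isqrt(b2) ** 2 != b2:
        a += 1
        b2 = a * a - n
    b = isqrt(b2)
    return a - b, a + b

def prime_factors(n: int) -> list[int]:
    """All prime factors of n >= 1, with multiplicity, in ascending order."""
    if n < 1:
        raise ValueError("n must be positive")
    twos = (n & -n).bit_length() - 1  # number of trailing zero bits of n
    factors = [2] * twos
    stack = [n >> twos]               # odd part still to be factored
    while stack:
        m = stack.pop()
        if m == 1:
            continue
        c, d = fermat_split(m)
        if c == 1:
            factors.append(m)         # only the trivial split exists: m is prime
        else:
            stack += [c, d]           # both parts are odd; keep splitting
    return sorted(factors)
```

For the three inputs reported as failing in the question, this returns `[5, 53, 163]`, `[3, 5, 73, 283]` and `[3, 11, 29, 71, 137]`.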