Tags: python, algorithm, math, euclidean-algorithm

What is the algorithm behind math.gcd, and why is it faster than the Euclidean algorithm?


Tests show that Python's math.gcd is an order of magnitude faster than a naive Euclidean algorithm implementation:

import math
from timeit import default_timer as timer

def gcd(a, b):
    """Return the greatest common divisor of the integers *a* and *b*.

    Naive Euclidean algorithm: repeatedly replace ``(a, b)`` with
    ``(b, a % b)`` until ``b`` becomes zero.

    The result is always non-negative, matching ``math.gcd``. Python's ``%``
    takes the sign of the divisor, so without the final ``abs`` an input such
    as ``gcd(4, -6)`` would yield ``-2``. ``gcd(0, 0)`` is ``0``.
    """
    while b:
        a, b = b, a % b
    return abs(a)

def main():
    """Compare the naive ``gcd`` with ``math.gcd`` on one pair of integers.

    For each implementation, print the GCD it computes, then print the
    average time of one call in seconds.

    The result is computed and printed outside the timed region. Each
    measurement averages many repetitions via ``timeit``. Timing a single
    call with ``print`` inside the timed region mostly measures console I/O
    and timer resolution rather than the GCD computation itself.
    """
    from timeit import timeit  # only needed here, for the benchmark

    a = 28871271685163
    b = 17461204521323
    number = 10000  # repetitions per measurement; averages out timer noise

    print(gcd(a, b))
    print(timeit(lambda: gcd(a, b), timer=timer, number=number) / number)

    print(math.gcd(a, b))
    print(timeit(lambda: math.gcd(a, b), timer=timer, number=number) / number)


if __name__ == "__main__":
    main()

gives

$ python3 test.py
1
4.816000000573695e-05
1
8.346003596670926e-06

e-05 vs e-06.

I guess there are some optimizations — or is it a different algorithm altogether?


Solution

  • math.gcd() is a thin Python-level wrapper around a library function that runs as machine code (i.e. compiled from C), rather than a function executed by the Python interpreter. See also: Where are math.py and sys.py?

    This should be it (for CPython):

    math_gcd(PyObject *module, PyObject * const *args, Py_ssize_t nargs)

    in mathmodule.c

    and it calls

    _PyLong_GCD(PyObject *aarg, PyObject *barg)

    in longobject.c

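    which uses Lehmer's GCD algorithm while the operands are longer than two internal digits, and then finishes with the ordinary Euclidean algorithm on native C integers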

    The code is smothered in housekeeping operations and special-case handling, which increases its complexity considerably. Still, it is quite clean.

    /* _PyLong_GCD: greatest common divisor of two Python ints.

       Returns a new reference to a non-negative int. Both operands are
       replaced by their absolute values before any arithmetic. Returns NULL
       if one of the helpers it calls fails (presumably with an exception
       already set by that helper).

       The caller's references to aarg and barg are not consumed: this
       function only releases references it took itself (Py_INCREF on the
       fast path, new references from long_abs otherwise).

       NOTE(review): aarg and barg are cast to PyLongObject * without a type
       check, so callers presumably pass int objects only (e.g. math_gcd
       after converting its arguments) -- confirm at call sites.

       Outline:
         1. Both operands have <= 2 internal digits: jump straight to
            "simple", a plain Euclidean loop on native C integers.
         2. Otherwise, run Lehmer's algorithm:
            - take the top 2*PyLong_SHIFT bits of a, and the bits of b at
              the same positions;
            - run Euclidean steps on these machine-word approximations,
              building a 2x2 cofactor matrix [A B; C D];
            - apply that matrix to the full-length numbers in one linear
              pass.
            Several Euclidean steps are therefore taken per pass over the
            big numbers.
         3. If the approximation yields no usable step (k == 0), do one
            full-precision Euclidean step with l_mod instead.
         4. Once a fits in two digits, finish in "simple". */
    PyObject *
    _PyLong_GCD(PyObject *aarg, PyObject *barg)
    {
        /* c, d: output objects for the new a and new b; they are kept
           across outer iterations so their storage can be reused. */
        PyLongObject *a, *b, *c = NULL, *d = NULL, *r;
        stwodigits x, y, q, s, t, c_carry, d_carry;
        stwodigits A, B, C, D, T;
        int nbits, k;
        digit *a_digit, *b_digit, *c_digit, *d_digit, *a_end, *b_end;
    
        a = (PyLongObject *)aarg;
        b = (PyLongObject *)barg;
        if (_PyLong_DigitCount(a) <= 2 && _PyLong_DigitCount(b) <= 2) {
            /* Fast path: both operands are small enough for the
               native-integer loop at "simple:". Take our own references so
               the cleanup there can release a and b uniformly on every
               path. */
            Py_INCREF(a);
            Py_INCREF(b);
            goto simple;
        }
    
        /* Initial reduction: make sure that 0 <= b <= a. */
        a = (PyLongObject *)long_abs(a);
        if (a == NULL)
            return NULL;
        b = (PyLongObject *)long_abs(b);
        if (b == NULL) {
            Py_DECREF(a);
            return NULL;
        }
        if (long_compare(a, b) < 0) {
            r = a;
            a = b;
            b = r;
        }
        /* We now own references to a and b */
    
        /* alloc_a / alloc_b appear to record the digit capacity of the
           objects currently held in a and b. They decide below whether a
           result can be written into an existing object, or whether it
           needs a fresh object or a trimmed copy. */
        Py_ssize_t size_a, size_b, alloc_a, alloc_b;
        alloc_a = _PyLong_DigitCount(a);
        alloc_b = _PyLong_DigitCount(b);
        /* reduce until a fits into 2 digits */
        while ((size_a = _PyLong_DigitCount(a)) > 2) {
            /* nbits: number of significant bits in a's top digit */
            nbits = bit_length_digit(a->long_value.ob_digit[size_a-1]);
            /* extract top 2*PyLong_SHIFT bits of a into x, along with
               corresponding bits of b into y */
            size_b = _PyLong_DigitCount(b);
            assert(size_b <= size_a);
            if (size_b == 0) {
                /* b == 0, so the gcd is a. If a now has fewer digits than
                   the object was allocated with, return a trimmed copy
                   instead (presumably to avoid handing out an
                   over-allocated int). r may be NULL if the copy fails;
                   that NULL is returned as-is. */
                if (size_a < alloc_a) {
                    r = (PyLongObject *)_PyLong_Copy(a);
                    Py_DECREF(a);
                }
                else
                    r = a;
                Py_DECREF(b);
                Py_XDECREF(c);
                Py_XDECREF(d);
                return (PyObject *)r;
            }
            /* x: a's top digits, shifted so that the leading set bit of
               a's top digit lands at bit 2*PyLong_SHIFT-1.
               y: b's bits at the same positions; digits that b does not
               have contribute 0. */
            x = (((twodigits)a->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits)) |
                 ((twodigits)a->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits)) |
                 (a->long_value.ob_digit[size_a-3] >> nbits));
    
            y = ((size_b >= size_a - 2 ? b->long_value.ob_digit[size_a-3] >> nbits : 0) |
                 (size_b >= size_a - 1 ? (twodigits)b->long_value.ob_digit[size_a-2] << (PyLong_SHIFT-nbits) : 0) |
                 (size_b >= size_a ? (twodigits)b->long_value.ob_digit[size_a-1] << (2*PyLong_SHIFT-nbits) : 0));
    
            /* inner loop of Lehmer's algorithm; A, B, C, D never grow
               larger than PyLong_MASK during the algorithm. */
            /* Each iteration is one Euclidean step (x, y) -> (y, x - q*y)
               on the approximations, with q derived from x and y adjusted
               by the current cofactors.
               The loop stops when the adjusted divisor y-C is zero, or when
               s > t. The latter is the test this code uses to decide that
               further quotients may no longer match those of the
               full-precision numbers (NOTE(review): presumably a
               Collins/Jebelean-style validity condition).
               k counts the accepted steps. */
            A = 1; B = 0; C = 0; D = 1;
            for (k=0;; k++) {
                if (y-C == 0)
                    break;
                q = (x+(A-1))/(y-C);
                s = B+q*D;
                t = x-q*y;
                if (s > t)
                    break;
                x = y; y = t;
                t = A+q*C; A = D; B = C; C = s; D = t;
            }
    
            if (k == 0) {
                /* no progress; do a Euclidean step */
                /* (a, b) <- (b, a mod b) at full precision; r receives the
                   new remainder object, and the alloc bookkeeping follows
                   the objects as they move. */
                if (l_mod(a, b, &r) < 0)
                    goto error;
                Py_SETREF(a, b);
                b = r;
                alloc_a = alloc_b;
                alloc_b = _PyLong_DigitCount(b);
                continue;
            }
    
            /*
              a, b = A*b-B*a, D*a-C*b if k is odd
              a, b = A*a-B*b, D*b-C*a if k is even
            */
            /* For odd k, negate and swap the cofactor pairs. This rewrites
               the odd-k formulas into the even-k form, so the single update
               loop below (c = A*a - B*b, d = D*b - C*a) serves both
               cases. */
            if (k&1) {
                T = -A; A = -B; B = T;
                T = -C; C = -D; D = T;
            }
            /* Choose where to write the new a (c):
                 - reuse the object from a previous iteration, reset to a
                   positive size_a-digit value;
                 - else, if our reference is the only one to a, overwrite a
                   in place;
                 - else allocate a fresh size_a-digit int. */
            if (c != NULL) {
                assert(size_a >= 0);
                _PyLong_SetSignAndDigitCount(c, 1, size_a);
            }
            else if (Py_REFCNT(a) == 1) {
                c = (PyLongObject*)Py_NewRef(a);
            }
            else {
                alloc_a = size_a;
                c = _PyLong_New(size_a);
                if (c == NULL)
                    goto error;
            }
    
            /* Same choice for the new b (d). Overwriting b in place also
               requires that it has room for size_a digits
               (size_a <= alloc_b). */
            if (d != NULL) {
                assert(size_a >= 0);
                _PyLong_SetSignAndDigitCount(d, 1, size_a);
            }
            else if (Py_REFCNT(b) == 1 && size_a <= alloc_b) {
                d = (PyLongObject*)Py_NewRef(b);
                assert(size_a >= 0);
                _PyLong_SetSignAndDigitCount(d, 1, size_a);
            }
            else {
                alloc_b = size_a;
                d = _PyLong_New(size_a);
                if (d == NULL)
                    goto error;
            }
    
            a_end = a->long_value.ob_digit + size_a;
            b_end = b->long_value.ob_digit + size_b;
    
            /* compute new a and new b in parallel */
            /* c = A*a - B*b and d = D*b - C*a, digit by digit, with signed
               carries. At each position, the digits of a and b are read
               before c and d are written there. The update therefore stays
               correct when c or d shares storage with a or b. The second
               loop handles a's digits beyond the end of b, where b
               contributes 0. */
            a_digit = a->long_value.ob_digit;
            b_digit = b->long_value.ob_digit;
            c_digit = c->long_value.ob_digit;
            d_digit = d->long_value.ob_digit;
            c_carry = 0;
            d_carry = 0;
            while (b_digit < b_end) {
                c_carry += (A * *a_digit) - (B * *b_digit);
                d_carry += (D * *b_digit++) - (C * *a_digit++);
                *c_digit++ = (digit)(c_carry & PyLong_MASK);
                *d_digit++ = (digit)(d_carry & PyLong_MASK);
                c_carry >>= PyLong_SHIFT;
                d_carry >>= PyLong_SHIFT;
            }
            while (a_digit < a_end) {
                c_carry += A * *a_digit;
                d_carry -= C * *a_digit++;
                *c_digit++ = (digit)(c_carry & PyLong_MASK);
                *d_digit++ = (digit)(d_carry & PyLong_MASK);
                c_carry >>= PyLong_SHIFT;
                d_carry >>= PyLong_SHIFT;
            }
            /* Zero final carries: both results fit in size_a digits and are
               non-negative (a negative result would leave a carry of -1). */
            assert(c_carry == 0);
            assert(d_carry == 0);
    
            /* The results become the new a and b. c and d keep their own
               references so their storage can be reused next iteration.
               long_normalize presumably trims leading zero digits --
               defined elsewhere. */
            Py_INCREF(c);
            Py_INCREF(d);
            Py_DECREF(a);
            Py_DECREF(b);
            a = long_normalize(c);
            b = long_normalize(d);
        }
        /* Drop the scratch references. a and b may refer to the same
           objects, but they hold their own references. */
        Py_XDECREF(c);
        Py_XDECREF(d);
    
    simple:
        /* Here a has at most two digits (from the fast-path check or the
           loop above), and so does b. The rest runs on native C integers,
           and the result is made non-negative with Py_ABS. */
        assert(Py_REFCNT(a) > 0);
        assert(Py_REFCNT(b) > 0);
    /* Issue #24999: use two shifts instead of ">> 2*PyLong_SHIFT" to avoid
       undefined behaviour when LONG_MAX type is smaller than 60 bits */
    #if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
    
        /* a fits into a long, so b must too */
        x = PyLong_AsLong((PyObject *)a);
        y = PyLong_AsLong((PyObject *)b);
    #elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
        x = PyLong_AsLongLong((PyObject *)a);
        y = PyLong_AsLongLong((PyObject *)b);
    #else
    # error "_PyLong_GCD"
    #endif
        x = Py_ABS(x);
        y = Py_ABS(y);
        Py_DECREF(a);
        Py_DECREF(b);
    
        /* usual Euclidean algorithm for longs */
        while (y != 0) {
            t = y;
            y = x % y;
            x = t;
        }
    #if LONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
        return PyLong_FromLong(x);
    #elif LLONG_MAX >> PyLong_SHIFT >> PyLong_SHIFT
        return PyLong_FromLongLong(x);
    #else
    # error "_PyLong_GCD"
    #endif
    
    error:
        /* Failure: release every reference still held and return NULL. */
        Py_DECREF(a);
        Py_DECREF(b);
        Py_XDECREF(c);
        Py_XDECREF(d);
        return NULL;
    }