Tags: python, python-3.x, algorithm, ackermann

Theoretically, can the Ackermann function be optimized?


I am wondering whether there can be a version of the Ackermann function with better time complexity than the standard one.

This is not homework; I am just curious. I know the Ackermann function has no practical use besides serving as a performance benchmark, because of its deep recursion. I know the numbers grow very large very quickly, and I am not interested in computing them.

I use Python 3, so the integers won't overflow, but I do have finite time. I have implemented a version myself according to the definition found on Wikipedia and computed the output for extremely small values, just to make sure it is correct.

[Image: the recursive definition of the Ackermann function from Wikipedia]
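The image itself is not available here. Written out from the code below, which translates it directly, it is the standard two-argument (Ackermann–Péter) definition:

```latex
A(m, n) =
\begin{cases}
n + 1 & \text{if } m = 0 \\
A(m - 1, 1) & \text{if } m > 0 \text{ and } n = 0 \\
A(m - 1, A(m, n - 1)) & \text{if } m > 0 \text{ and } n > 0
\end{cases}
```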

def A(m, n):
    # A(0, n) = n + 1
    if not m:
        return n + 1
    # A(m, n) = A(m - 1, A(m, n - 1)) for n > 0, and A(m, 0) = A(m - 1, 1)
    return A(m - 1, A(m, n - 1)) if n else A(m - 1, 1)

The above code is a direct translation of the image and is extremely slow. I don't know how it can be optimized. Is it impossible to optimize it?

One thing I can think of is memoization, but the recursion runs backwards: each time the function is called recursively, the arguments have not been encountered before, because with each successive call the arguments decrease rather than increase. Therefore every return value has to be calculated, and memoization doesn't help when the function is called with given arguments for the first time.

Memoization can only help if you call the function with the same arguments again: instead of computing the result, it retrieves the cached one. But if you call the function with any input (m, n) >= (4, 2), it will crash the interpreter regardless.

I also implemented another version according to this answer:

def ack(x, y):
    for i in range(x, 0, -1):
        y = ack(i, y - 1) if y else 1
    return y + 1

But it is actually slower:

In [2]: %timeit A(3, 4)
1.3 ms ± 9.75 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [3]: %timeit ack(3, 4)
2 ms ± 59.9 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
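`%timeit` is an IPython magic. Outside IPython, a comparable measurement can be sketched with the standard `timeit` module. `best_time` is a hypothetical helper name, and the numbers it prints depend on the machine:

```python
import timeit


def A(m, n):
    # Direct recursive definition, copied from the question.
    if not m:
        return n + 1
    return A(m - 1, A(m, n - 1)) if n else A(m - 1, 1)


def ack(x, y):
    # Loop-based variant, copied from the question.
    for i in range(x, 0, -1):
        y = ack(i, y - 1) if y else 1
    return y + 1


def best_time(func, m, n, number=100, repeat=5):
    # timeit.repeat returns the total time of `number` calls for each of
    # `repeat` runs; take the fastest run and divide to get seconds per call.
    totals = timeit.repeat(lambda: func(m, n), number=number, repeat=repeat)
    return min(totals) / number


if __name__ == "__main__":
    for func in (A, ack):
        print(f"{func.__name__}(3, 4): {best_time(func, 3, 4) * 1e3:.3f} ms per call")
```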

Theoretically, can the Ackermann function be optimized? If not, can it be proven definitively that its time complexity cannot be reduced?


I have just tested it: A(3, 9) and A(4, 1) crash the interpreter. Here is the performance of the two functions for A(3, 8):

In [2]: %timeit A(3, 8)
432 ms ± 4.63 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

In [3]: %timeit ack(3, 8)
588 ms ± 10.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
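The crash is presumably Python's recursion limit (a RecursionError by default, or a real stack overflow if the limit has been raised), because the nesting depth of the naive recursion grows with the size of the intermediate values. A common workaround, sketched here under the assumption that only the depth is the problem, is to keep the pending outer calls on an explicit list instead of Python frames. `A_explicit_stack` is a hypothetical name. It performs the same steps as `A`, so it does not improve the time complexity; it only removes the depth limit:

```python
def A_explicit_stack(m, n):
    # Pending outer calls are stored as a list of m values.
    # `n` always holds the argument (or just-returned value) for the next step.
    stack = [m]
    while stack:
        m = stack.pop()
        if m == 0:
            # A(0, n) = n + 1; the result feeds the next pending call.
            n += 1
        elif n == 0:
            # A(m, 0) = A(m - 1, 1)
            stack.append(m - 1)
            n = 1
        else:
            # A(m, n) = A(m - 1, A(m, n - 1)): evaluate the inner call first,
            # then apply A(m - 1, .) to its result.
            stack.append(m - 1)
            stack.append(m)
            n -= 1
    return n


print(A_explicit_stack(3, 5))  # 253
```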

I did some more experiments:

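from collections import Counter
from functools import cache

c = Counter()
def A1(m, n):
    # Count how often each (m, n) pair is visited
    c[(m, n)] += 1
    if not m:
        return n + 1
    # Recurse into A1 (not A) so that every nested call is counted too
    return A1(m - 1, A1(m, n - 1)) if n else A1(m - 1, 1)

def test(m, n):
    c.clear()
    A1(m, n)
    return c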

The arguments indeed repeat.

But, surprisingly, caching doesn't help at all:

In [9]: %timeit Ackermann = cache(A); Ackermann(3, 4)
1.3 ms ± 10.1 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Caching only helps when the function is called with the same arguments again, as explained above:

In [14]: %timeit Ackermann(3, 2)
101 ns ± 0.47 ns per loop (mean ± std. dev. of 7 runs, 10,000,000 loops each)

I have tested it with different arguments numerous times, and it always gives the same efficiency boost (which is none).
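One likely reason the cache had no effect: `cache(A)` returns a new wrapper, but the recursive calls inside `A` still look up the global name `A`, which is the undecorated function, so only the outermost call is cached (and the `%timeit` line above also creates a fresh, empty cache on every loop). Decorating the function at definition time routes every recursive call through the cache. A sketch, with the hypothetical name `A_cached`; it removes repeated work but not the recursion depth problem:

```python
from functools import cache


@cache
def A_cached(m, n):
    # The decorator rebinds the name A_cached, so the recursive calls below
    # go through the cached wrapper and repeated (m, n) pairs are looked up.
    if not m:
        return n + 1
    return A_cached(m - 1, A_cached(m, n - 1)) if n else A_cached(m - 1, 1)


if __name__ == "__main__":
    print(A_cached(3, 4))
    # cache_info() reports hits and misses; a nonzero hit count shows that
    # the recursion really does revisit arguments.
    print(A_cached.cache_info())
```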


Solution


    I recently wrote a bunch of solutions based on the same paper that templatetypedef mentioned. Many use generators, one for each m-value, yielding the values for n=0, n=1, n=2, etc. This one might be my favorite:

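    from itertools import count, islice
    
    def A_Stefan_generator_stack3(m, n):
        def a(m):
            # Yields A(m, 0), A(m, 1), A(m, 2), ...
            if not m:
                yield from count(1)  # A(0, n) = n + 1; infinite, so never falls through
            x = 1  # A(m, 0) = A(m-1, 1), so the first value sits at index 1 of a(m-1)
            for i, ai in enumerate(a(m-1)):
                if i == x:  # here ai == A(m-1, x) == A(m, n)
                    x = ai
                    yield x
        return next(islice(a(m), n, None))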
    

    Explanation

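    Consider the generator a(m). It yields A(m,0), A(m,1), A(m,2), etc. The definition of A(m,n) uses A(m-1, A(m, n-1)). So a(m) at its index n yields A(m,n), computed like this: x holds the previously yielded value A(m,n-1) (initially 1, because A(m,0) = A(m-1,1)). The generator walks through a(m-1), i.e., through A(m-1,0), A(m-1,1), A(m-1,2), etc., and when it reaches index i == x, the value there is A(m-1,x) = A(m,n), which it yields and keeps as the new x. Since A(m-1,x) > x, the next target index always lies further ahead, so each a(m-1) is consumed once, in order, and no value is computed twice.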
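To see the selection step concretely, here is an instrumented copy of the inner generator of A_Stefan_generator_stack3. The `trace` list is an addition for illustration only; it records, for each yielded value, which index of a(m-1) it was taken from:

```python
from itertools import count, islice


def a(m, trace):
    # Same logic as the generator in A_Stefan_generator_stack3, plus logging.
    if not m:
        yield from count(1)          # A(0, n) = n + 1 for n = 0, 1, 2, ...
    x = 1                            # A(m, 0) = A(m - 1, 1): look at index 1 first
    for i, ai in enumerate(a(m - 1, trace)):
        if i == x:                   # ai is A(m - 1, x), i.e. the next A(m, n)
            trace.append((m, i, ai))
            x = ai
            yield x


trace = []
values = list(islice(a(2, trace), 4))
print(values)                                      # [3, 5, 7, 9] = A(2, 0..3)
print([(i, ai) for m, i, ai in trace if m == 2])   # [(1, 3), (3, 5), (5, 7), (7, 9)]
```

Each value picked from a(1) becomes the index of the next pick, which is exactly the A(m-1, A(m, n-1)) step.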

    Benchmark

    Here are times for computing all A(m,n) for m≤3 and n≤17, also including templatetypedef's solution:

     1325 ms  A_Stefan_row_class
     1228 ms  A_Stefan_row_lists
      544 ms  A_Stefan_generators
     1363 ms  A_Stefan_paper
      459 ms  A_Stefan_generators_2
      866 ms  A_Stefan_m_recursion
      704 ms  A_Stefan_function_stack
      468 ms  A_Stefan_generator_stack
      945 ms  A_Stefan_generator_stack2
      582 ms  A_Stefan_generator_stack3
      467 ms  A_Stefan_generator_stack4
     1652 ms  A_templatetypedef
    

    Note: Even faster (much faster) solutions using mathematical insights/formulas are possible; see my comment and pts's answer. I intentionally didn't do that, as I was interested in coding techniques for avoiding deep recursion and avoiding re-calculation. I got the impression that this is also what the OP wanted, and they have now confirmed it (under a deleted answer, visible if you have enough reputation).
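For comparison, a sketch of the kind of math shortcut the note refers to (not pts's code; `A_closed_form` is a hypothetical name). For small m the rows have well-known closed forms: A(1,n) = n+2, A(2,n) = 2n+3, A(3,n) = 2^(n+3) - 3, and A(4,n) is a power tower of n+3 twos, minus 3. This sketch uses them and gives up for larger m:

```python
def A_closed_form(m, n):
    # Closed forms for the first rows of the Ackermann function.
    if m == 0:
        return n + 1
    if m == 1:
        return n + 2
    if m == 2:
        return 2 * n + 3
    if m == 3:
        return 2 ** (n + 3) - 3
    if m == 4:
        # A tower of n + 3 twos, minus 3. Only n <= 2 is feasible:
        # A(4, 2) = 2**65536 - 3 already has almost 20,000 decimal digits.
        tower = 1
        for _ in range(n + 3):
            tower = 2 ** tower
        return tower - 3
    raise ValueError("this sketch has no closed form for m > 4")


print(A_closed_form(3, 8))  # 2045
print(A_closed_form(4, 1))  # 65533
```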

    Code

    def A_Stefan_row_class(m, n):
        class A0:
            def __getitem__(self, n):
                return n + 1
        class A:
            def __init__(self, a):
                self.a = a
                self.n = 0
                self.value = a[1]
            def __getitem__(self, n):
                while self.n < n:
                    self.value = self.a[self.value]
                    self.n += 1
                return self.value
        a = A0()
        for _ in range(m):
            a = A(a)
        return a[n]
    
    
    from collections import defaultdict
    
    def A_Stefan_row_lists(m, n):
        memo = defaultdict(list)
        def a(m, n):
            if not m:
                return n + 1
            if m not in memo:
                memo[m] = [a(m-1, 1)]
            Am = memo[m]
            while len(Am) <= n:
                Am.append(a(m-1, Am[-1]))
            return Am[n]
        return a(m, n)
    
    
    from itertools import count
    
    def A_Stefan_generators(m, n):
        a = count(1)
        def up(a, x=1):
            for i, ai in enumerate(a):
                if i == x:
                    x = ai
                    yield x
        for _ in range(m):
            a = up(a)
        return next(up(a, n))
    
    
    def A_Stefan_paper(m, n):
        next = [0] * (m + 1)
        goal = [1] * m + [-1]
        while True:
            value = next[0] + 1
            transferring = True
            i = 0
            while transferring:
                if next[i] == goal[i]:
                    goal[i] = value
                else:
                    transferring = False
                next[i] += 1
                i += 1
            if next[m] == n + 1:
                return value
    
    
    def A_Stefan_generators_2(m, n):
        def a0():
            n = yield
            while True:
                n = yield n + 1
        def up(a):
            next(a)
            a = a.send
            i, x = -1, 1
            n = yield
            while True:
                while i < n:
                    x = a(x)
                    i += 1
                n = yield x
        a = a0()
        for _ in range(m):
            a = up(a)
        next(a)
        return a.send(n)
    
    
    def A_Stefan_m_recursion(m, n):
        ix = [None] + [(-1, 1)] * m
        def a(m, n):
            if not m:
                return n + 1
            i, x = ix[m]
            while i < n:
                x = a(m-1, x)
                i += 1
            ix[m] = i, x
            return x
        return a(m, n)
    
    
    def A_Stefan_function_stack(m, n):
        def a(n):
            return n + 1
        for _ in range(m):
            def a(n, a=a, ix=[-1, 1]):
                i, x = ix
                while i < n:
                    x = a(x)
                    i += 1
                ix[:] = i, x
                return x
        return a(n)
    
    
    from itertools import count, islice
    
    def A_Stefan_generator_stack(m, n):
        a = count(1)
        for _ in range(m):
            a = (
                x
                for a, x in [(a, 1)]
                for i, ai in enumerate(a)
                if i == x
                for x in [ai]
            )
        return next(islice(a, n, None))
    
    
    from itertools import count, islice
    
    def A_Stefan_generator_stack2(m, n):
        a = count(1)
        def up(a):
            i, x = 0, 1
            while True:
                i, x = x+1, next(islice(a, x-i, None))
                yield x
        for _ in range(m):
            a = up(a)
        return next(islice(a, n, None))
    
    
    def A_Stefan_generator_stack3(m, n):
        def a(m):
            if not m:
                yield from count(1)
            x = 1
            for i, ai in enumerate(a(m-1)):
                if i == x:
                    x = ai
                    yield x
        return next(islice(a(m), n, None))
    
    
    def A_Stefan_generator_stack4(m, n):
        def a(m):
            if not m:
                return count(1)
            return (
                x
                for x in [1]
                for i, ai in enumerate(a(m-1))
                if i == x
                for x in [ai]
            )
        return next(islice(a(m), n, None))
    
    
    def A_templatetypedef(i, n):
        positions = [-1] * (i + 1)
        values = [0] + [1] * i
        
        while positions[i] != n:       
            values[0]    += 1
            positions[0] += 1
                
            j = 1
            while j <= i and positions[j - 1] == values[j]:
                values[j] = values[j - 1]
                positions[j] += 1
                j += 1
    
        return values[i]
    
    
    funcs = [
        A_Stefan_row_class,
        A_Stefan_row_lists,
        A_Stefan_generators,
        A_Stefan_paper,
        A_Stefan_generators_2,
        A_Stefan_m_recursion,
        A_Stefan_function_stack,
        A_Stefan_generator_stack,
        A_Stefan_generator_stack2,
        A_Stefan_generator_stack3,
        A_Stefan_generator_stack4,
        A_templatetypedef,
    ]
    
    N = 18
    args = (
        [(0, n) for n in range(N)] +
        [(1, n) for n in range(N)] +
        [(2, n) for n in range(N)] +
        [(3, n) for n in range(N)]
    )
    
    from time import time
    
    def print(*args, print=print, file=open('out.txt', 'w')):
        print(*args)
        print(*args, file=file, flush=True)
        
    expect = none = object()
    for _ in range(3):
        for f in funcs:
            t = time()
            result = [f(m, n) for m, n in args]
            # print(f'{(time()-t) * 1e3 :5.1f} ms ', f.__name__)
            print(f'{(time()-t) * 1e3 :5.0f} ms ', f.__name__)
            if expect is none:
                expect = result
            elif result != expect:
                raise Exception(f'{f.__name__} failed')
            del result
        print()