Tags: python, algorithm, performance, math, collatz

Maximizing Efficiency of Collatz Conjecture Program Python


My question is very simple.

I wrote this program for pure entertainment. It takes a numerical input and finds the length of the Collatz sequence for every starting number up to and including that number.

I want to make it faster algorithmically or mathematically (I know I could make it faster by running multiple instances in parallel or by rewriting it in C++, but where's the fun in that?).

Any and all help is welcome, thanks!

EDIT: Code further optimized with the help of dankal444

from matplotlib import pyplot as plt
import numpy as np
import numba as nb

# Get the largest starting number to check; sequence lengths are computed
# for every starting value from 1 up to and including top_range.
top_range = int(input('Top Range: '))
if top_range < 1:
    # A bound below 1 leaves no Collatz sequences to compute or plot, so fail
    # fast with a clear message instead of erroring later inside collatz().
    raise SystemExit('Top Range must be a positive integer.')

@nb.njit('int64[:](int_)')
def collatz(top_range):
    """Return an int64 array of Collatz step counts for 0..top_range.

    Entry n is the number of Collatz moves (x -> x/2 when even,
    x -> 3x+1 when odd) needed to bring n down to 1. Entries 0 and 1 are
    left at 0. The array is filled in one ascending pass, and every entry
    is derived from an entry with a smaller index that is already filled.
    """
    lengths = np.zeros(top_range + 1, dtype=np.int64)
    for n in range(2, top_range + 1):
        residue = n % 4
        if residue == 0 or residue == 2:
            # Even: a single halving lands on n // 2.
            lengths[n] = lengths[n // 2] + 1
        elif residue == 1:
            # n = 4k+1: 3n+1 = 12k+4, and two halvings reach 3k+1 < n (3 moves).
            half_step = n + (n >> 1) + 1  # equals (3n+1)/2 for odd n
            lengths[n] = lengths[half_step // 2] + 3
        else:
            # n = 4k+3: apply the combined odd move (3x+1)/2 twice up front
            # (2 moves each, so 4 moves in total)...
            x = n + (n >> 1) + 1
            x += (x >> 1) + 1
            steps = 4
            # ...then keep walking until the trajectory drops below n,
            # where the answer is already tabulated.
            while x >= n:
                if x % 2 == 0:
                    x //= 2
                    steps += 1
                else:
                    x += (x >> 1) + 1
                    steps += 2
            lengths[n] = lengths[x] + steps
    return lengths

mem = collatz(top_range)

# Plot each starting number against the length of its sequence.
# mem[i] is the step count for starting value i, and mem[0] is unused
# padding. Starting values 1..top_range are therefore paired with mem[1:];
# pairing range(1, len(mem) + 1) with all of mem would shift every point
# one unit to the right.
starts = np.arange(1, len(mem))
plt.scatter(starts, mem[1:], color='black', s=1)
plt.show()

Solution

  • Applying numba to your code helps a lot.

    I removed tqdm, since it does not help performance (the `tqdm` import below is now unused).

    import time
    from matplotlib import pyplot as plt
    from tqdm import tqdm
    
    import numpy as np
    import numba as nb
    @nb.njit('int64[:](int_)')
    def collatz2(top_range):
        """Numba-compiled Collatz step counts for every start in 0..top_range.

        Returns an int64 array whose entry n is the number of Collatz moves
        needed to bring n down to 1 (entries 0 and 1 stay 0).
        """
        table = np.zeros(top_range + 1, dtype=np.int64)
        for n in range(2, top_range + 1):
            r = n % 4
            if r == 1:
                # n = 4k+1: 3n+1, then two halvings reach 3k+1 < n (3 moves).
                table[n] = table[(n + (n >> 1) + 1) // 2] + 3
            elif r == 3:
                # n = 4k+3: follow the trajectory until it drops below n.
                # The odd move is folded into (3x+1)/2 and counts as 2.
                x = n
                moves = 0
                while x >= n:
                    if x % 2 == 0:
                        x //= 2
                        moves += 1
                    else:
                        x += (x >> 1) + 1
                        moves += 2
                table[n] = table[x] + moves
            else:
                # n % 4 is 0 or 2 (even): one halving reaches n // 2.
                table[n] = table[n // 2] + 1
        return table
    
    
    def collatz(top_range):
        """Pure-Python reference: Collatz step counts for 0..top_range.

        Returns a list ``steps`` where ``steps[n]`` is the number of Collatz
        moves needed to bring ``n`` down to 1; ``steps[0]`` and ``steps[1]``
        stay 0.
        """
        steps = [0] * (top_range + 1)
        for n in range(2, top_range + 1):
            remainder = n % 4
            if remainder in (0, 2):
                # Even: a single halving reaches n // 2.
                steps[n] = steps[n // 2] + 1
                continue
            if remainder == 1:
                # n = 4k+1 -> 3n+1 -> /2 -> /2 lands on 3k+1 (< n) in 3 moves.
                steps[n] = steps[(n + (n >> 1) + 1) // 2] + 3
                continue
            # n = 4k+3: walk the trajectory until it falls below n, using the
            # combined odd move (3x+1)/2, which counts as 2 steps.
            value, taken = n, 0
            while value >= n:
                if value % 2 == 0:
                    value //= 2
                    taken += 1
                else:
                    value += (value >> 1) + 1
                    taken += 2
            steps[n] = steps[value] + taken
        return steps
    
    # profiling here
    def main(top_range=1_000_000):
        """Cross-check the pure-Python and numba Collatz implementations.

        Both functions are run up to ``top_range`` (default 1_000_000) and
        must produce identical step counts.

        Raises:
            AssertionError: if the two implementations disagree.
        """
        mem = collatz(top_range)
        mem2 = collatz2(top_range)
        # Step counts are integers, so compare exactly rather than with the
        # floating-point tolerance of np.allclose. An explicit raise keeps
        # the check active even when Python runs with -O.
        if not np.array_equal(np.array(mem), mem2):
            raise AssertionError('collatz and collatz2 disagree')
    
    
    

    For top_range = 1_000, the optimized function is ~100x faster; for top_range = 1_000_000, it is about 600x faster (profiling output below):

        79                                           def main():
        81         1          3.0      3.0      0.0      top_range = 1_000_000
        83         1   24633045.0 24633045.0     98.7      mem = collatz(top_range)
        85         1      39311.0  39311.0      0.2      mem2 = collatz2(top_range)