Tags: python, math, sympy, mathematical-optimization

How can I speed up a never-ending SymPy expansion of the product of all (a + b*z**k), given some z(n)?


Let me give some context first.

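You see, I have a theory: if you multiply all the factors (a + b*z**k), for a z determined by n and a suitable set of exponents k, you get a**n + b**n. This is my brand-new Cori theorem (don't ask who Cori is), and I kind of needed to prove it. Here is what each piece means. For odd n, z is a primitive n-th root of unity (for example exp(2*pi*I/n)) and k runs over 0, 1, ..., n-1. For even n, z is a primitive 2n-th root of unity, i.e. a root of x**n + 1 (for example exp(pi*I/n)), and k runs over the odd numbers 1, 3, ..., 2n-1.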

So, my goal is to:

  1. Calculate all possible ks and the unique z
  2. Map every k to a linear factor a + b*z**k
  3. Multiply all those linear factors together
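Before waiting on SymPy, you can sanity-check these three steps numerically with plain complex floats. This is a minimal sketch, not a proof. It assumes z = exp(i*pi/n) with odd k = 1, 3, ..., 2n-1 for even n, and z = exp(2*i*pi/n) with k = 0, ..., n-1 for odd n, which are the choices made by the commented-out line and the exponent lists in the code below. `cori_product` is a hypothetical helper name.

```python
import cmath

def cori_product(n, a, b):
    """Numerically multiply all (a + b*z**k).

    Assumption: for even n, z = exp(i*pi/n) and k = 1, 3, ..., 2n-1;
    for odd n, z = exp(2*i*pi/n) and k = 0, 1, ..., n-1.
    """
    if n % 2 == 0:
        z = cmath.exp(1j * cmath.pi / n)
        exponents = [2 * k + 1 for k in range(n)]
    else:
        z = cmath.exp(2j * cmath.pi / n)
        exponents = list(range(n))
    product = 1
    for k in exponents:
        product *= a + b * z**k  # step 2 and step 3, one factor at a time
    return product

if __name__ == "__main__":
    a, b = 1.3, -0.7
    for n in range(2, 13):
        got = cori_product(n, a, b)
        # Compare against a**n + b**n, allowing for floating-point rounding
        print(n, cmath.isclose(got, a**n + b**n, rel_tol=1e-9, abs_tol=1e-9))
```

Floating-point agreement across many values of a and b is good evidence. Only the symbolic computation, or a proof, settles the identity.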

The challenge isn't writing the code or fixing a weird bug. The problem arises when you set n=7, which is the only reason I built this code. I think you should see the problem for yourself, so here you go:

import sympy as sp
from sympy.solvers import solve

def cycles(n):
    # a and b will be used for product
    # x will be used for unit calculation
    a, b, x = sp.symbols('a b x')
    if n % 2 == 0:
        exponents = [2*k+1 for k in range(0, n)]
        units = solve(x**n + 1, x)
    else:
        exponents = list(range(0, n))
        units = solve(x**n - 1, x)
    # z = sp.exp(sp.I * sp.pi * (1 if n % 2 == 0 else 2) / n)
    z = units[n-2]

    factors = [a + b*z**k for k in exponents]
    product = sp.expand(sp.Mul(* factors))
    return product

if __name__ == "__main__":
    for i in range(2, 10):
        c = cycles(i)
        print(c)
        # TODO: replace sp.simplify with something that works for big n
        print(sp.simplify(c))

You should see a**i + b**i twice for each i in range(2, 7), then a really big expression, and then a freeze. The wait lasts more than an HOUR, and even if you skip 7, any larger n makes the wait even LONGER. I've never seen it return anything, so I'm not even sure it will return the a**7 + b**7 I'm expecting. Maybe it's better to check my theory by hand? But I'm a programmer, so...

Is there any way to speed this up? Or at least to check whether it's doing anything at all?
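One way to check whether the expansion is making progress is to multiply the factors one at a time and report the size of the running product after each step. Here is a minimal sketch. `expand_with_progress` is a hypothetical helper. To use it on the question's code, pass it the `factors` list built inside `cycles` instead of calling `sp.expand(sp.Mul(*factors))`. The sketch also uses the standard-library `faulthandler.dump_traceback_later` to print the current stack every 60 seconds, which shows where a long call is stuck.

```python
import faulthandler
import time

import sympy as sp

def expand_with_progress(factors):
    """Multiply factors one at a time, expanding after each step and printing
    the elapsed time and the number of additive terms in the running product."""
    start = time.perf_counter()
    product = sp.Integer(1)
    for i, factor in enumerate(factors, start=1):
        product = sp.expand(product * factor)
        n_terms = len(sp.Add.make_args(product))
        print(f"factor {i}/{len(factors)}: {n_terms} terms, "
              f"{time.perf_counter() - start:.2f}s elapsed")
    return product

if __name__ == "__main__":
    # Dump every thread's stack to stderr every 60 s while the work runs
    faulthandler.dump_traceback_later(60, repeat=True)
    a, b = sp.symbols('a b')
    # Small example: (a + b)*(a - b) expands to a**2 - b**2
    print(expand_with_progress([a + b, a - b]))
    faulthandler.cancel_dump_traceback_later()
```

The final answer can have at most n + 1 monomials a**j*b**(n-j). If the term count keeps growing well beyond that, the extra terms come from the symbolic coefficients z**k being expanded.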


Addendum

Edit: I was so bored waiting that I proved we should always expect a**i + b**i. We should, but I still haven't gotten it. Yes, I'm still waiting...
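For reference, here is a sketch of why a**i + b**i is always expected. It assumes z and the exponents k are chosen as in the code: for odd n, z = exp(2*pi*I/n) and k = 0, ..., n-1; for even n, z = exp(pi*I/n) and k = 1, 3, ..., 2n-1. In both cases the powers z**k are exactly the n distinct roots r of x^n = ε, where ε = 1 for odd n and ε = -1 for even n.

```latex
x^n - \varepsilon = \prod_{r^n = \varepsilon} (x - r),
\qquad
\varepsilon = \begin{cases} 1 & n \text{ odd} \\ -1 & n \text{ even} \end{cases}

\prod_{k} \left(a + b z^k\right)
  = \prod_{r^n = \varepsilon} (a + b r)
  = (-b)^n \prod_{r^n = \varepsilon} \left(-\frac{a}{b} - r\right)
  = (-b)^n \left(\left(-\frac{a}{b}\right)^n - \varepsilon\right)
  = a^n - \varepsilon (-1)^n b^n
  = a^n + b^n
```

The last step uses ε(-1)^n = -1, which holds for every n. Dividing by b requires b ≠ 0, but both sides are polynomials in b, so the identity also holds at b = 0.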

Another edit: So I verified the math, and it's all good. Let me casually thank @ti7 for the better title, and @Oscar Benjamin for cracking the case. I had to tidy up a few things to better fit the general formula. I know it's messy, but it shows the structure I wanted. So... here it is:

import sympy as sp

def cycles(n):
    
    # Use 't' as a variable
    t = sp.symbols('t', real=True)
    exponents = list(range(0, n))
    
    # Get the roots of unity
    z =  sp.exp(2 * sp.I * sp.pi / n)
    z2 = sp.exp(    sp.I * sp.pi / n)

    # Get their algebraic field
    z =  sp.Poly(z , t, domain=sp.QQ.algebraic_field(z ))
    z2 = sp.Poly(z2, t, domain=sp.QQ.algebraic_field(z2))
    
    # Multiply them
    factors = [t + z2*z**k for k in exponents]
    return sp.prod(factors).as_expr()

if __name__ == "__main__":
    for i in range(2, 10):
        print('===', i, '===')
        print(cycles(i), '\n')

Thanks again. Now I have a theorem I can say is true :D
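In the final version above, z is just z2**2, so a single algebraic field, QQ(z2), already contains both numbers. Below is a possible variant that builds everything in that one field, so SymPy never has to combine two separately constructed fields. `cycles_single_field` is a hypothetical name. Since z2*z**k = exp((2k + 1)*pi*I/n) runs over the n roots r of x**n = -1, the product of (t + r) is (-1)**n*((-t)**n + 1), that is, t**n + (-1)**n. This gives t**n + 1 for even n and t**n - 1 for odd n.

```python
import sympy as sp

t = sp.Symbol('t', real=True)

def cycles_single_field(n):
    """Multiply all (t + z2*z**k), k = 0..n-1, inside one algebraic field.

    z2 = exp(i*pi/n) and z = z2**2 = exp(2*i*pi/n), so QQ(z2) contains z.
    """
    z2_expr = sp.exp(sp.I * sp.pi / n)
    domain = sp.QQ.algebraic_field(z2_expr)
    z2 = sp.Poly(z2_expr, t, domain=domain)  # constant polynomial z2
    z = z2**2                                # exp(2*i*pi/n) in the same field
    t_poly = sp.Poly(t, t, domain=domain)
    factors = [t_poly + z2 * z**k for k in range(n)]
    return sp.prod(factors).as_expr()

if __name__ == "__main__":
    for i in range(2, 10):
        print('===', i, '===')
        print(cycles_single_field(i), '\n')
```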


Solution

  • You should use Poly, which manipulates expressions in a more structured (and much faster) way:

    import sympy as sp
    
    def cycles(n):
        # a and b are the symbols in each linear factor a + b*z**k
        a, b = sp.symbols('a b')
        if n % 2 == 0:
            exponents = [2*k+1 for k in range(0, n)]
        else:
            exponents = list(range(0, n))
        z = sp.exp(sp.I * sp.pi * (1 if n % 2 == 0 else 2) / n)
    
        # Treat z as an exact element of the algebraic field QQ(z)
        domain = sp.QQ.algebraic_field(z)
        z = sp.Poly(z, a, b, domain=domain)
        factors = [a + b*z**k for k in exponents]
        return sp.prod(factors).as_expr()
    
    if __name__ == "__main__":
        for i in range(2, 10):
            c = cycles(i)
            print(c)
            # c is already a**i + b**i here, so simplify has little left to do
            print(sp.simplify(c))
    

    Output:

    $ python t.py 
    a**2 + b**2
    a**2 + b**2
    a**3 + b**3
    a**3 + b**3
    a**4 + b**4
    a**4 + b**4
    a**5 + b**5
    a**5 + b**5
    a**6 + b**6
    a**6 + b**6
    a**7 + b**7
    a**7 + b**7
    a**8 + b**8
    a**8 + b**8
    a**9 + b**9
    a**9 + b**9
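This is why the Poly version stays small. In QQ.algebraic_field(z), SymPy roughly stores each element as a polynomial in the field's generator, reduced modulo its minimal polynomial. A power like z**7 is therefore computed exactly and collapses, instead of growing into a larger symbolic expression. The following sketch for n = 7 uses only the calls from the answer plus `sympy.minimal_polynomial`:

```python
import sympy as sp

a, x = sp.symbols('a x')
n = 7
z_expr = sp.exp(2 * sp.I * sp.pi / n)

# Minimal polynomial of z over QQ (for n = 7 the cyclotomic polynomial of degree 6)
print(sp.minimal_polynomial(z_expr, x))

domain = sp.QQ.algebraic_field(z_expr)
z = sp.Poly(z_expr, a, domain=domain)

# Powers of z are reduced exactly: z**7 is the constant 1, and z**8 is z again
print((z**n - 1).is_zero)
print(z**(n + 1) == z)
```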