python sympy symbolic-math

Compare SymPy expressions modulo variable remapping


I want to compare two SymPy expressions and determine whether they are equivalent up to a renaming of their symbolic parameters. My first approach converts the expressions to polynomials, which works well for polynomial expressions but fails on more complex ones. Here is the code:

import sympy

def sym_expr_eq(expr1, expr2, symbols=None):
    """Return whether two expressions have the same polynomial shape in *symbols*.

    Both expressions are expanded and converted to polynomials in the given
    generators. They are considered equivalent when they contain the same set
    of monomials and every pair of coefficients that are *both* plain numbers
    is equal. A coefficient that is not a plain number (for example one that
    still contains a symbolic parameter) is not compared, so it matches
    anything.

    Args:
        expr1: First expression (anything ``sympy.expand`` accepts).
        expr2: Second expression.
        symbols: Iterable of generator symbols. Defaults to no generators.

    Returns:
        True if the expressions match under the rules above, else False.

    Raises:
        sympy.polys.polyerrors.PolynomialError: propagated from ``sympy.Poly``
            when an expression is not polynomial in *symbols*.
    """
    gens = () if symbols is None else tuple(symbols)

    expr1 = sympy.expand(expr1)
    expr2 = sympy.expand(expr2)

    # sympy.Poly reports failure by raising, never by returning None, so no
    # None check is needed on the results.
    poly1 = sympy.Poly(expr1, *gens)
    poly2 = sympy.Poly(expr2, *gens)

    if set(poly1.monoms()) != set(poly2.monoms()):
        return False

    # The monomial sets are equal, so both coefficient lists have the same
    # length and can be walked in lockstep.
    for coeff1, coeff2 in zip(poly1.coeffs(), poly2.coeffs()):
        if coeff1.is_Number and coeff2.is_Number and coeff1 != coeff2:
            return False

    return True


def test_sym_expr_eq_1():
    """Two different numeric constants must not be reported as equivalent."""
    x = sympy.Symbol('x')

    assert not sym_expr_eq(1, 2, [x])

def test_sym_expr_eq_2():
    """A numeric constant is equivalent to a lone parameter symbol."""
    d, x = sympy.symbols('d x')

    assert sym_expr_eq(1, d, [x])

def test_sym_expr_eq_3():
    """Parameter-plus-constant terms are equivalent regardless of the constant."""
    b, d, x = sympy.symbols('b d x')

    lhs = b + 1
    rhs = d - 2

    assert sym_expr_eq(lhs, rhs, [x])

def test_sym_expr_eq_4():
    """Affine expressions in x with different parameter names are equivalent."""
    a, b, c, d, x = sympy.symbols('a b c d x')

    lhs = a*x + b
    rhs = c*x + d

    assert sym_expr_eq(lhs, rhs, [x])

def test_sym_expr_eq_5():
    """Affine expressions in different tracked variables are not equivalent."""
    a, b, c, d, x, y = sympy.symbols('a b c d x y')

    lhs = a*x + b
    rhs = c*y + d

    assert not sym_expr_eq(lhs, rhs, [x, y])

def test_sym_expr_eq_6():
    """Numeric offsets mixed into parametric coefficients do not matter."""
    a, b, c, d, x = sympy.symbols('a b c d x')

    lhs = (1+a)*x + b + 3
    rhs = (2+c)*x + d

    assert sym_expr_eq(lhs, rhs, [x])

def test_sym_expr_eq_7():
    """Already-expanded forms with split x terms are still equivalent."""
    a, b, c, d, x = sympy.symbols('a b c d x')

    lhs = x + a*x + b + 3
    rhs = 2*x + c*x + d - 1

    assert sym_expr_eq(lhs, rhs, [x])

def test_sym_expr_eq_8():
    """Differing purely numeric coefficients (3*x vs 2*x) break equivalence."""
    a, c, d, f, x = sympy.symbols('a c d f x')

    lhs = a*x**2 + 3*x + c
    rhs = d*x**2 + 2*x + f

    assert not sym_expr_eq(lhs, rhs, [x])

def test_sym_expr_eq_9():
    """Logarithmic expressions with renamed parameters should be equivalent."""
    x = sympy.Symbol('x')

    lhs = sympy.sympify("a*log(b*x+c)+d")
    rhs = sympy.sympify("e+log(g+x*h)*f")

    assert sym_expr_eq(lhs, rhs, [x])

All the tests pass except the 9th, which raises a sympy.polys.polyerrors.PolynomialError exception, and I don't know how to handle it. Treating the expression as a polynomial in more complex functions might work, but I don't know how to do that either.


Solution

  • I added some more tests and improved your code; all of the tests now pass:
    - test_sym_expr_eq_14 is tricky: I don't yet know why it sometimes passes and sometimes fails.
    I think these cases are exhaustive; I can't see any other cases relevant to my goal at the moment.

    import sympy
    
    def symplify_sym_expr(expr, symbols):
        """Normalise *expr* so that parameter-only parts collapse to one placeholder.

        The expression is collected on the terms reported by ``expr_terms``,
        each direct argument is normalised recursively, and the result is
        simplified. If the result is a sum or a product that has at least one
        non-numeric argument containing none of *symbols*, then all numeric
        arguments and all but the first such argument are replaced by the
        neutral element (0 for a sum, 1 for a product), the first one is
        replaced by ``newSymbol()``, and the expression is simplified again.

        Args:
            expr: SymPy expression to normalise.
            symbols: Iterable of the tracked variable symbols.

        Returns:
            The normalised SymPy expression.
        """
        expr = sympy.collect(expr, expr_terms(expr, symbols))

        # expr.args is read once here; rebinding expr inside the loop does not
        # change the sequence being iterated.
        for arg in expr.args:
            expr = expr.subs(arg, symplify_sym_expr(arg, symbols))

        expr = sympy.simplify(expr)

        # Sums and products are handled identically except for the value that
        # leaves the expression unchanged when substituted in.
        if expr.is_Add:
            neutral = 0
        elif expr.is_Mul:
            neutral = 1
        else:
            return expr

        numbers = []
        param_only = []  # non-numeric arguments containing no tracked symbol
        for arg in expr.args:
            if arg.is_Number:
                numbers.append(arg)
            elif not any(arg.has(s) for s in symbols):
                param_only.append(arg)

        if param_only:
            for number in numbers:
                expr = expr.subs(number, neutral)

            for extra in param_only[1:]:
                expr = expr.subs(extra, neutral)

            # NOTE(review): newSymbol is not defined in the visible code;
            # presumably it returns a fresh, unused Symbol - confirm.
            expr = expr.subs(param_only[0], newSymbol())

            expr = sympy.simplify(expr)

        return expr
    
    def same_ast_structure(expr1, expr2, symbols):
        """Return whether two expression trees match up to parameter renaming.

        Nodes match when they have the same type and the same number of
        arguments, and their arguments match pairwise after being sorted by
        the string form of ``expr_terms``. Leaves are compared as follows:

        * a Number matches a Symbol only if that Symbol is not in *symbols*
          (a free parameter may take a numeric value, a tracked variable may
          not);
        * two Symbols must be identical if either of them is in *symbols*;
        * two same-typed constant expressions (``is_number``) must be equal.

        Args:
            expr1: First SymPy expression.
            expr2: Second SymPy expression.
            symbols: Collection of the tracked variable symbols.

        Returns:
            True if the trees match, False otherwise.
        """
        if type(expr1) != type(expr2):
            # The only tolerated type mismatch is number <-> parameter symbol.
            if expr1.is_Number and expr2.is_Symbol:
                parameter = expr2
            elif expr2.is_Number and expr1.is_Symbol:
                parameter = expr1
            else:
                return False
            # A tracked variable must never be matched against a number.
            if parameter in symbols:
                return False
        elif expr1.is_Symbol:
            if (expr1 in symbols or expr2 in symbols) and expr1 != expr2:
                return False
        elif expr1.is_number and expr1 != expr2:
            return False

        if len(expr1.args) != len(expr2.args):
            return False

        def ordered_args(expr):
            # Sort on the parameter-free term skeleton so that the argument
            # order does not depend on parameter names; sorted() is stable.
            return sorted(expr.args,
                          key=lambda arg: str(expr_terms(arg, symbols)))

        return all(same_ast_structure(a1, a2, symbols)
                   for a1, a2 in zip(ordered_args(expr1), ordered_args(expr2)))
    
    def expr_terms(expr, symbols):
        """Return the term skeleton(s) of *expr* as a list.

        * For a sum, the result has one entry per addend: the first skeleton
          of each addend.
        * For a product that contains at least one Symbol or one compound
          factor, the single entry is the product with its numeric factors
          and its Symbol factors not listed in *symbols* replaced by 1, then
          simplified.
        * Anything else is returned unchanged as a one-element list.

        Args:
            expr: SymPy expression to inspect.
            symbols: Collection of the tracked variable symbols.

        Returns:
            A list of SymPy expressions.
        """
        if expr.is_Add:
            return [expr_terms(arg, symbols)[0] for arg in expr.args]

        if not expr.is_Mul:
            return [expr]

        numbers = []
        symbol_factors = []
        has_expr = False
        for arg in expr.args:
            if arg.is_Number:
                numbers.append(arg)
            elif arg.is_Symbol:
                symbol_factors.append(arg)
            else:
                has_expr = True

        if not (has_expr or symbol_factors):
            return [expr]

        # Keep the factors in argument order (rather than going through a set)
        # so the substitutions below run in a reproducible order.
        parameters = [s for s in symbol_factors if s not in symbols]

        term = expr
        for factor in numbers + parameters:
            term = term.subs(factor, 1)

        return [sympy.simplify(term)]
    
    def sym_expr_eq(a, b, symbols=None):
        """Return whether *a* and *b* are equivalent up to parameter renaming.

        Each input is sympified, expanded, normalised with
        ``symplify_sym_expr`` and expanded again; the two results are then
        compared with ``same_ast_structure``.

        Args:
            a: First expression (anything ``sympy.sympify`` accepts).
            b: Second expression.
            symbols: Collection of the tracked variable symbols. Defaults to
                an empty list.

        Returns:
            True if the normalised expressions have the same structure.
        """
        # Avoid a shared mutable default argument.
        symbols = [] if symbols is None else symbols

        def normalise(expr):
            expr = sympy.expand(sympy.sympify(expr))
            expr = symplify_sym_expr(expr, symbols)
            return sympy.expand(expr)

        # Normalise a before b, as the original code did.
        norm_a = normalise(a)
        norm_b = normalise(b)

        return same_ast_structure(norm_a, norm_b, symbols)
    
    def test_sym_expr_eq_1():
        """Two different numeric constants must not be reported as equivalent."""
        x = sympy.Symbol('x')

        assert not sym_expr_eq(1, 2, [x])
    
    def test_sym_expr_eq_2():
        """A numeric constant is equivalent to a lone parameter symbol."""
        d, x = sympy.symbols('d x')

        assert sym_expr_eq(1, d, [x])
    
    def test_sym_expr_eq_3():
        """Parameter-plus-constant terms are equivalent regardless of the constant."""
        b, d, x = sympy.symbols('b d x')

        lhs = b + 1
        rhs = d - 2

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_4():
        """Affine expressions in x with different parameter names are equivalent."""
        a, b, c, d, x = sympy.symbols('a b c d x')

        lhs = a*x + b
        rhs = c*x + d

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_5():
        """Affine expressions in different tracked variables are not equivalent."""
        a, b, c, d, x, y = sympy.symbols('a b c d x y')

        lhs = a*x + b
        rhs = c*y + d

        assert not sym_expr_eq(lhs, rhs, [x, y])
    
    def test_sym_expr_eq_6():
        """Numeric offsets mixed into parametric coefficients do not matter."""
        a, b, c, d, x = sympy.symbols('a b c d x')

        lhs = (1+a)*x + b + 3
        rhs = (2+c)*x + d

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_7():
        """Already-expanded forms with split x terms are still equivalent."""
        a, b, c, d, x = sympy.symbols('a b c d x')

        lhs = x + a*x + b + 3
        rhs = 2*x + c*x + d - 1

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_8():
        """Differing purely numeric coefficients (3*x vs 2*x) break equivalence."""
        a, c, d, f, x = sympy.symbols('a c d f x')

        lhs = a*x**2 + 3*x + c
        rhs = d*x**2 + 2*x + f

        assert not sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_9():
        """Logarithmic expressions with renamed parameters are equivalent."""
        x = sympy.Symbol('x')

        lhs = sympy.sympify("a*log(b*x+c)+d")
        rhs = sympy.sympify("e+log(g+x*h)*f")

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_10():
        """Logarithmic expressions in different tracked variables are not equivalent."""
        x, y = sympy.symbols('x y')

        lhs = sympy.sympify("a*log(b*x+c)+d")
        rhs = sympy.sympify("e+log(g+y*h)*f")

        assert not sym_expr_eq(lhs, rhs, [x, y])
    
    def test_sym_expr_eq_11():
        """Numeric parts inside the log argument's x coefficient do not matter."""
        x = sympy.Symbol('x')

        lhs = sympy.sympify("a*log(b*x+c+2*x)+d")
        rhs = sympy.sympify("e+log(g+x*(1+h))*f")

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_12():
        """Products of several parameters are equivalent to a single parameter."""
        x = sympy.Symbol('x')

        lhs = sympy.sympify("a*x+2*d")
        rhs = sympy.sympify("b*c*x+e*f")

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_13():
        """Repeated log(x) terms collapse into a single parametric log(x) term."""
        x = sympy.Symbol('x')

        lhs = sympy.sympify("a*log(x)+2*d")
        rhs = sympy.sympify("b*log(x)+2*log(x)+e*f")

        assert sym_expr_eq(lhs, rhs, [x])
    
    def test_sym_expr_eq_14():
        """An expanded sin/exp product matches its factored parametric form.

        Calls ``sym_expr_eq`` directly, like the other tests; the previous
        ``sr.`` prefix referred to a name that is never defined here.
        """
        expr1 = sympy.sympify("_71*sin(_0*x + _1) + _72*exp(_62)*exp(_61*x) + _73*exp(_62)*exp(_61*x)*sin(_0*x + _1) + _74")
        expr2 = sympy.sympify("i*(a*sin(b*x+c)+d)*(e*exp(f*x+g)+h)+j")

        assert sym_expr_eq(expr1, expr2, [sympy.Symbol("x")])
    
    def test_sym_expr_eq_15():
        """A parametric affine form in x and y matches a fully numeric instance.

        Calls ``sym_expr_eq`` directly, like the other tests; the previous
        ``sr.`` prefix referred to a name that is never defined here.
        """
        a, b, c, x, y = sympy.symbols('a b c x y')

        expr1 = a * x + b * y + c
        expr2 = 2 * x + 3 * y + 4

        assert sym_expr_eq(expr1, expr2, [x, y])