I want to compare two SymPy expressions and determine whether they are equivalent up to some symbolic parameters. A first approach is to treat them as polynomials, which works well but fails with more complex (non-polynomial) expressions. Here is the code:
import sympy
def sym_expr_eq(expr1, expr2, symbols=None):
    """Polynomial-based check that two expressions match up to free parameters.

    Both expressions are expanded and viewed as polynomials in *symbols*.
    They match when they have the same set of monomials and, for every
    monomial whose two coefficients are both plain numbers, those numbers
    are equal; a coefficient that involves anything else (a parameter)
    matches any coefficient.

    Args:
        expr1, expr2: SymPy expressions (or values ``sympy.expand`` accepts).
        symbols: the polynomial generators; defaults to no explicit
            generators (SymPy then infers them from the expression).

    Returns:
        True if the two polynomials have the same shape, False otherwise.

    Raises:
        sympy.polys.polyerrors.PolynomialError: if an expression is not a
            polynomial in *symbols* (e.g. ``log(b*x + c)``). ``sympy.Poly``
            raises in that case; it never returns ``None``.
    """
    symbols = [] if symbols is None else list(symbols)
    poly1 = sympy.Poly(sympy.expand(expr1), *symbols)
    poly2 = sympy.Poly(sympy.expand(expr2), *symbols)
    # Map each monomial to its coefficient so coefficients are compared by
    # monomial rather than by position, and computed only once.
    terms1 = dict(zip(poly1.monoms(), poly1.coeffs()))
    terms2 = dict(zip(poly2.monoms(), poly2.coeffs()))
    if terms1.keys() != terms2.keys():
        return False
    return all(
        terms1[monom] == terms2[monom]
        for monom in terms1
        if terms1[monom].is_Number and terms2[monom].is_Number
    )
def test_sym_expr_eq_1():
    """Two different literal numbers must not match."""
    x = sympy.Symbol('x')
    assert not sym_expr_eq(1, 2, [x])
def test_sym_expr_eq_2():
    """A literal number matches a free parameter."""
    d, x = sympy.symbols('d x')
    assert sym_expr_eq(1, d, [x])
def test_sym_expr_eq_3():
    """Parameter-plus-constant sums match each other."""
    b, d, x = sympy.symbols('b d x')
    assert sym_expr_eq(b + 1, d - 2, [x])
def test_sym_expr_eq_4():
    """Two linear functions of x with parameter coefficients match."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq(a*x + b, c*x + d, [x])
def test_sym_expr_eq_5():
    """A linear function of x does not match one of y when both are variables."""
    a, b, c, d, x, y = sympy.symbols('a b c d x y')
    assert not sym_expr_eq(a*x + b, c*y + d, [x, y])
def test_sym_expr_eq_6():
    """Numeric offsets inside coefficients and constant terms are absorbed."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq((1 + a)*x + b + 3, (2 + c)*x + d, [x])
def test_sym_expr_eq_7():
    """Like test 6, but with the x terms written out separately."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq(x + a*x + b + 3, 2*x + c*x + d - 1, [x])
def test_sym_expr_eq_8():
    """Different literal coefficients of the same monomial must not match."""
    a, c, d, f, x = sympy.symbols('a c d f x')
    assert not sym_expr_eq(a*x**2 + 3*x + c, d*x**2 + 2*x + f, [x])
def test_sym_expr_eq_9():
    """Parametrized logarithms of a linear function of x match."""
    x = sympy.Symbol('x')
    expr1 = sympy.sympify("a*log(b*x+c)+d")
    expr2 = sympy.sympify("e+log(g+x*h)*f")
    assert sym_expr_eq(expr1, expr2, [x])
All the tests pass except the 9th, which raises `sympy.polys.polyerrors.PolynomialError`, and I don't know how to handle it. Maybe treating the expression as a polynomial in more complex functions (such as `log(...)`) could work, but I don't know how to do that either.
I added some other tests and improved your code, and all the tests passed:
- test_sym_expr_eq_14 is tricky: I don't know yet why it passes sometimes and fails at other times.
I think these cases are exhaustive; I don't see any other cases needed for my goal at the moment.
import sympy
def symplify_sym_expr(expr, symbols):
    """Normalize *expr* by absorbing parameter-only operands into fresh symbols.

    The expression is first collected on its term keys (``expr_terms``).
    Its children are then normalized recursively, and the result is simplified.
    Finally, if it is an ``Add`` (resp. ``Mul``), every operand that contains
    none of *symbols* is merged into a single new symbol from ``newSymbol()``.
    Any literal-number operands are merged along with them.

    NOTE(review): ``newSymbol`` is not defined in this listing; it is assumed
    to return a fresh, unique SymPy symbol on each call -- confirm.

    Args:
        expr: a SymPy expression.
        symbols: the real variables; every other symbol is a free parameter.

    Returns:
        The normalized expression.
    """
    expr = sympy.collect(expr, expr_terms(expr, symbols))
    if expr.args:
        # Rebuild from the normalized children instead of chaining
        # ``expr.subs(child, new_child)``. ``subs`` also rewrites matching
        # sub-products inside *other* children (e.g. ``b*x`` inside
        # ``b*x*y``), which then defeats the substitutions that follow.
        expr = expr.func(*[symplify_sym_expr(arg, symbols) for arg in expr.args])
    expr = sympy.simplify(expr)
    if expr.is_Add:
        expr = _absorb_parameters(expr, sympy.Add, symbols)
    elif expr.is_Mul:
        expr = _absorb_parameters(expr, sympy.Mul, symbols)
    return expr


def _absorb_parameters(expr, op, symbols):
    """Replace the parameter-only operands of *expr* (built with *op*) by one fresh symbol.

    Operands containing any of *symbols* are kept. Literal numbers are dropped
    together with the parameter-only operands, because the fresh symbol can
    absorb them. *expr* is returned unchanged when it has no parameter-only,
    non-numeric operand.
    """
    kept = []
    has_parameter_operand = False
    for arg in expr.args:
        if arg.is_Number:
            continue
        if any(arg.has(s) for s in symbols):
            kept.append(arg)
        else:
            has_parameter_operand = True
    if not has_parameter_operand:
        return expr
    # Rebuild rather than ``expr.subs(number, 0 or 1)``: substituting a number
    # also rewrites it inside the kept operands (``x**2`` would become ``x**0``).
    return sympy.simplify(op(*kept, newSymbol()))
def same_ast_structure(expr1, expr2, symbols):
    """Return True if two normalized expressions have the same structure.

    Rules, applied recursively:

    * a symbol from *symbols* matches only itself;
    * any other symbol (a free parameter) matches any other free parameter
      of the same type and any literal number;
    * numeric sub-expressions of the same type must be equal;
    * otherwise node types must be identical and children must match.
      Children of ``Add``/``Mul`` are commutative and are paired by a
      bipartite matching, so the result does not depend on SymPy's argument
      order. Children of every other node (``Pow`` base/exponent, function
      arguments, ...) are compared positionally.
    """
    if type(expr1) != type(expr2):
        # The only tolerated type mismatch is a literal number against a free
        # parameter. A variable from *symbols* must never match a number,
        # otherwise e.g. ``2`` would be considered equivalent to ``x``.
        return (_is_number_vs_parameter(expr1, expr2, symbols)
                or _is_number_vs_parameter(expr2, expr1, symbols))
    if expr1.is_Symbol:
        if expr1 in symbols or expr2 in symbols:
            return expr1 == expr2
        return True
    if expr1.is_number and expr1 != expr2:
        return False
    args1, args2 = expr1.args, expr2.args
    if len(args1) != len(args2):
        return False
    if expr1.is_Add or expr1.is_Mul:
        return _has_perfect_matching(args1, args2, symbols)
    # Order matters here: sorting would make ``x**2`` match ``2**x``.
    return all(same_ast_structure(a1, a2, symbols) for a1, a2 in zip(args1, args2))


def _is_number_vs_parameter(number, symbol, symbols):
    """Return True if *number* is a literal number and *symbol* is a free parameter."""
    return bool(number.is_Number and symbol.is_Symbol and symbol not in symbols)


def _has_perfect_matching(args1, args2, symbols):
    """Return True if each element of *args1* pairs with a distinct matching element of *args2*.

    Uses Kuhn's augmenting-path algorithm; *args1* and *args2* must have the
    same length (checked by the caller).
    """
    candidates = [
        [j for j, a2 in enumerate(args2) if same_ast_structure(a1, a2, symbols)]
        for a1 in args1
    ]
    owner = [None] * len(args2)  # owner[j]: index in args1 paired with args2[j]

    def assign(i, visited):
        for j in candidates[i]:
            if j in visited:
                continue
            visited.add(j)
            if owner[j] is None or assign(owner[j], visited):
                owner[j] = i
                return True
        return False

    return all(assign(i, set()) for i in range(len(args1)))
def expr_terms(expr, symbols):
    """Return the term keys of *expr*, used as ``sympy.collect`` targets.

    * ``Add``: one key per summand (the first key of each summand).
    * ``Mul``: the simplified product of the factors that remain after
      dropping literal numbers and free-parameter symbols (symbols not in
      *symbols*).
    * anything else: the expression itself.

    Returns:
        A list of SymPy expressions; for a non-``Add`` input it has exactly
        one element.
    """
    if expr.is_Add:
        return [expr_terms(arg, symbols)[0] for arg in expr.args]
    if not expr.is_Mul:
        return [expr]
    has_compound_factor = any(not (arg.is_Number or arg.is_Symbol) for arg in expr.args)
    has_symbol_factor = any(arg.is_Symbol for arg in expr.args)
    if not (has_compound_factor or has_symbol_factor):
        return [expr]
    # Rebuild the product from the kept factors instead of substituting 1 for
    # the dropped ones. ``expr.subs(2, 1)`` would also turn ``x**2`` into
    # ``x``, and ``expr.subs(a, 1)`` would strip ``a`` inside e.g. ``sin(a*x)``.
    kept = [
        arg for arg in expr.args
        if not arg.is_Number and not (arg.is_Symbol and arg not in symbols)
    ]
    return [sympy.simplify(sympy.Mul(*kept))]
def sym_expr_eq(a, b, symbols=None):
    """Return True if *a* and *b* have the same structure up to free parameters.

    Both inputs are sympified, expanded, normalized with
    ``symplify_sym_expr`` and expanded again. They are then compared with
    ``same_ast_structure``.

    Args:
        a, b: SymPy expressions or anything ``sympy.sympify`` accepts.
        symbols: iterable of the real variables; every other symbol is
            treated as a free parameter. Defaults to no variables.
    """
    symbols = [] if symbols is None else list(symbols)
    return same_ast_structure(
        _normalize_sym_expr(a, symbols),
        _normalize_sym_expr(b, symbols),
        symbols,
    )


def _normalize_sym_expr(expr, symbols):
    """Sympify and expand *expr*, absorb its parameters, and expand the result."""
    expr = sympy.expand(sympy.sympify(expr))
    expr = symplify_sym_expr(expr, symbols)
    return sympy.expand(expr)
def test_sym_expr_eq_1():
    """Two different literal numbers must not match."""
    x = sympy.Symbol('x')
    assert not sym_expr_eq(1, 2, [x])
def test_sym_expr_eq_2():
    """A literal number matches a free parameter."""
    d, x = sympy.symbols('d x')
    assert sym_expr_eq(1, d, [x])
def test_sym_expr_eq_3():
    """Parameter-plus-constant sums match each other."""
    b, d, x = sympy.symbols('b d x')
    assert sym_expr_eq(b + 1, d - 2, [x])
def test_sym_expr_eq_4():
    """Two linear functions of x with parameter coefficients match."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq(a*x + b, c*x + d, [x])
def test_sym_expr_eq_5():
    """A linear function of x does not match one of y when both are variables."""
    a, b, c, d, x, y = sympy.symbols('a b c d x y')
    assert not sym_expr_eq(a*x + b, c*y + d, [x, y])
def test_sym_expr_eq_6():
    """Numeric offsets inside coefficients and constant terms are absorbed."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq((1 + a)*x + b + 3, (2 + c)*x + d, [x])
def test_sym_expr_eq_7():
    """Like test 6, but with the x terms written out separately."""
    a, b, c, d, x = sympy.symbols('a b c d x')
    assert sym_expr_eq(x + a*x + b + 3, 2*x + c*x + d - 1, [x])
def test_sym_expr_eq_8():
    """Different literal coefficients of the same monomial must not match."""
    a, c, d, f, x = sympy.symbols('a c d f x')
    assert not sym_expr_eq(a*x**2 + 3*x + c, d*x**2 + 2*x + f, [x])
def test_sym_expr_eq_9():
    """Parametrized logarithms of a linear function of x match."""
    x = sympy.Symbol('x')
    expr1 = sympy.sympify("a*log(b*x+c)+d")
    expr2 = sympy.sympify("e+log(g+x*h)*f")
    assert sym_expr_eq(expr1, expr2, [x])
def test_sym_expr_eq_10():
    """A logarithm of a linear function of x does not match one of y."""
    x, y = sympy.symbols('x y')
    expr1 = sympy.sympify("a*log(b*x+c)+d")
    expr2 = sympy.sympify("e+log(g+y*h)*f")
    assert not sym_expr_eq(expr1, expr2, [x, y])
def test_sym_expr_eq_11():
    """Numeric offsets in a coefficient inside a logarithm are absorbed."""
    x = sympy.Symbol('x')
    expr1 = sympy.sympify("a*log(b*x+c+2*x)+d")
    expr2 = sympy.sympify("e+log(g+x*(1+h))*f")
    assert sym_expr_eq(expr1, expr2, [x])
def test_sym_expr_eq_12():
    """A product of parameters behaves like a single parameter."""
    x = sympy.Symbol('x')
    expr1 = sympy.sympify("a*x+2*d")
    expr2 = sympy.sympify("b*c*x+e*f")
    assert sym_expr_eq(expr1, expr2, [x])
def test_sym_expr_eq_13():
    """Repeated log(x) terms collapse into one parametrized term."""
    x = sympy.Symbol('x')
    expr1 = sympy.sympify("a*log(x)+2*d")
    expr2 = sympy.sympify("b*log(x)+2*log(x)+e*f")
    assert sym_expr_eq(expr1, expr2, [x])
def test_sym_expr_eq_14():
    """An expanded product of a sine and an exponential matches its factored form.

    Calls the module-level ``sym_expr_eq``; the previous ``sr.sym_expr_eq``
    referred to a name ``sr`` that is not defined in this file.
    """
    expr1 = sympy.sympify("_71*sin(_0*x + _1) + _72*exp(_62)*exp(_61*x) + _73*exp(_62)*exp(_61*x)*sin(_0*x + _1) + _74")
    expr2 = sympy.sympify("i*(a*sin(b*x+c)+d)*(e*exp(f*x+g)+h)+j")
    assert sym_expr_eq(expr1, expr2, [sympy.Symbol("x")])
def test_sym_expr_eq_15():
    """Numeric coefficients match parameter coefficients in two variables.

    Calls the module-level ``sym_expr_eq``; the previous ``sr.sym_expr_eq``
    referred to a name ``sr`` that is not defined in this file.
    """
    a, b, c, x, y = sympy.symbols('a b c x y')
    expr1 = a * x + b * y + c
    expr2 = 2 * x + 3 * y + 4
    assert sym_expr_eq(expr1, expr2, [x, y])