sympy

Do not expand sub-expressions until required


I wish to differentiate a complex function. Unfortunately, as I construct the expression to be differentiated, SymPy immediately expands every sub-expression down to the symbols it is built from. As a result, once I differentiate the final expression, it balloons to an unmanageable size. I have tried manually substituting some sub-expressions back to their prior form in the final result, but because the expression gets restructured during construction, this misses a lot of opportunities to simplify.

Here is only the first stage of what I want to differentiate, and it is already hopelessly bloated even after trying to substitute back some of the expressions used to construct it.

from sympy import *

aX, aY, aZ = symbols('aX aY aZ')
rotInc = Matrix(3,1,[aX, aY, aZ])
theta = sqrt((rotInc.T @ rotInc)[0,0])
incQuat = Quaternion.from_axis_angle(rotInc/theta, theta*2)

qX, qY, qZ, qW = symbols('qX qY qZ qW')
baseQuat = Quaternion(qW, qX, qY, qZ)
poseQuat = incQuat * baseQuat

d4 = diff(poseQuat, aX)
d4s = d4.subs({
    incQuat.a: symbols('iW'),
    incQuat.b: symbols('iX'),
    incQuat.c: symbols('iY'),
    incQuat.d: symbols('iZ'),
    theta: symbols('theta')
})

I know of cse (common subexpression elimination), and it shows me that there is some kind of system for keeping named sub-expressions around. I'd prefer if SymPy built a structure like the one cse returns while I am constructing the expression, and only replaced a sub-expression with its components when necessary, e.g. during differentiation. That would keep most of these symbols around in the final derivative, resulting in cleaner output that is immediately usable.
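For reference, a minimal sketch of what cse returns, using a small made-up expression (the names `x`, `y`, and `expr` are only for illustration): a list of (symbol, sub-expression) pairs plus the expression rewritten in terms of those symbols.

```python
from sympy import symbols, sqrt, sin, cos, cse, simplify

x, y = symbols('x y')

# A small illustrative expression in which sqrt(x**2 + y**2) appears twice.
expr = sin(sqrt(x**2 + y**2)) + cos(sqrt(x**2 + y**2))

# cse returns (replacements, reduced_exprs):
#   replacements  - list of (new_symbol, sub_expression) pairs
#   reduced_exprs - list with the input rewritten using the new symbols
replacements, reduced = cse(expr)
print(replacements)
print(reduced[0])

# Substituting the pairs back in reverse order recovers the original.
rebuilt = reduced[0]
for sym, sub in reversed(replacements):
    rebuilt = rebuilt.subs(sym, sub)
print(simplify(rebuilt - expr))
```

The catch is that the new symbols are plain symbols, so differentiating the reduced expression directly would treat them as constants.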

Is there such a mode / way of constructing expressions in sympy or something else that helps me keep the final expression simple?


Solution

  • I thought something like this already existed in SymPy, but perhaps it is only used internally, e.g. where cse prepares expressions for compiled functions (or something like that). Following your idea (and probably redundant with the thinking in this issue):

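    from sympy import symbols, Function, Derivative, cse
    
    def dep(r, x):
        # Map each cse symbol whose definition depends on x (directly or
        # through earlier cse symbols) to an undefined function of x.
        reps = {}
        for v, e in r:
            if e.xreplace(reps).has(x):
                reps[v] = Function(str(v))(x)
        return reps
    
    def differentiate_with_cse(expr, x, backsub=False):
        r, e = cse(expr)
        reps = dep(r, x)
        # Pair the derivative of each placeholder function with the
        # derivative of the sub-expression it stands for.
        dr = []
        for v, s in r:
            if v in reps:
                f = reps[v]
                df = Derivative(f, x).doit()
                ds = s.xreplace(reps).diff(x)
                dr.append((df, ds))
        # Differentiate the reduced expression, then fill in the derivatives.
        dexpr = e[0].subs(reps).diff(x).subs(dr).expand()
        # Turn the placeholder functions back into the plain cse symbols.
        for k, v in reversed(reps.items()):
            dexpr = dexpr.subs(v, k)
        # Optionally expand the cse symbols back into the original symbols.
        result = dexpr.subs(list(reversed(r))) if backsub else dexpr
        return result, dict(r)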
    
    ...your code up to derivative
    
    >>> deriv, pats = differentiate_with_cse(poseQuat, aX)
    >>> deriv.simplify()
    (-qX*x7) + qW*x7*i + (-qZ*x7)*j + qY*x7*k
    >>> pats
    {x0: aX**2, x1: aY**2, x10: aZ*x7, x2: aZ**2, x3: x0 + x1 + x2, x4: sqrt(x3), x5: cos(x4), x6: 1/x3, x7: sin(x4)/(x4*sqrt(x0*x6 + x1*x6 + x2*x6)), x8: aX*x7, x9: aY*x7}
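The underlying trick can also be done by hand on a smaller case. The following is only a sketch with made-up names (`theta_expr`, `theta_fn`, `outer` are not from the question): an intermediate quantity is replaced by an undefined function of the differentiation variable, so the chain rule leaves a named derivative behind instead of an expanded expression.

```python
from sympy import symbols, Function, Derivative, sqrt, sin, simplify

a, b = symbols('a b')

# The intermediate quantity we want to keep by name (illustrative only).
theta_expr = sqrt(a**2 + b**2)

# Opaque stand-in: SymPy only knows that it depends on a.
theta_fn = Function('theta')(a)

# Build the outer expression from the stand-in and differentiate it.
outer = sin(theta_fn) / theta_fn
d_outer = outer.diff(a)  # still contains Derivative(theta(a), a)

# d(theta)/da = a/sqrt(a**2 + b**2), written in terms of theta itself.
d_theta = a / theta_fn

# Substitute the derivative first, then turn the stand-in into a symbol.
theta = symbols('theta')
compact = d_outer.subs(Derivative(theta_fn, a), d_theta).subs(theta_fn, theta)
print(compact)

# Sanity check against differentiating the fully expanded expression.
direct = (sin(theta_expr) / theta_expr).diff(a)
difference = simplify(compact.subs(theta, theta_expr) - direct)
print(difference)
```

The order of the two substitutions matters here: the derivative has to be replaced while the stand-in is still a function of `a`.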