Tags: python, sympy, solver

sympy.solve never returns for certain coefficients


The form of the expression whose critical points I'm trying to find (the variable called 'expression' below) is identical in all cases. The only thing that changes between runs that succeed and runs that hang is the values of the exponents on each term.

Here is a minimal reproducible example:

from sympy import symbols, Mul, diff, solve
from decimal import Decimal


class Histogram:
    def __init__(self, bins, pdf):
        self.bins = bins
        self.pdf = pdf


histogram = Histogram(
    bins=[-1, 0, 1],
    pdf=[Decimal('0.454861111111111104943205418749130330979824066162109375'),
         Decimal('0'),
         Decimal('0.545138888888888839545643349993042647838592529296875')]
)


def computation(histogram):

    A = Decimal('1')
    e = Decimal('1')
    M_factor = Decimal('1.0828567056280801')
    x = symbols('x')
    product_terms = []
    profit = lambda cp: (e * M_factor ** cp - e) * x
    commission = lambda cp: (e * M_factor ** cp + e) * Decimal('0.0005') * x

    for candles_profit, freq in zip(histogram.bins, histogram.pdf):
        profit_term = profit(candles_profit)
        commission_term = commission(candles_profit)
        term = ((A + profit_term - commission_term) / A) ** freq
        product_terms.append(term)

    expression = Mul(*product_terms)
    derivative = diff(expression, x)
    solutions = solve(derivative, x)
    print("Solutions:", solutions)


computation(histogram)

In the debugger, the 'expression' variable in the computation function above evaluates to:

(1.0 - 0.0774785191289289*x)**0.454861111111111 * (0.0818152772752661*x + 1.0)**0.545138888888889
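The Float coefficients and exponents in that output come from SymPy converting the Decimal operands as soon as they are combined with the Symbol x. A minimal sketch of that conversion, reusing values copied from the histogram and the debugger output above (the names freq, coeff and term are only for illustration):

```python
from decimal import Decimal

from sympy import Float, symbols

x = symbols('x')

# Values copied from the question's histogram and debugger output.
freq = Decimal('0.454861111111111104943205418749130330979824066162109375')
coeff = Decimal('0.0774785191289289')

# Mixing a Decimal with a SymPy Symbol sympifies the Decimal into a Float,
# so both the coefficient and the exponent of this power become Floats.
term = (1 - coeff * x) ** freq

print(type(term.exp))     # the exponent is a SymPy Float, not a Rational
print(term.atoms(Float))  # every Float that appears in the term
```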

If you save the example above as a Python file and run it, solve() never returns when it tries to find the expression's critical points.

The one feature I can see that may explain the failure is that the exponents are repeating decimals, i.e. rational numbers such as 2/3. To confirm my intuition that this is the source of the error, I revisited the trades that formed the histogram for this symbol: 131 were in the -1 bucket and 157 were in the +1 bucket. 131/288 = 0.4548611... with the 1 repeating, and 157/288 = 0.5451388... with the 8 repeating.

So even though the Decimal values in the histogram passed to the computation function keep many digits, they are not exact: the repetition stops after about 16 significant digits. Even that doesn't matter, because SymPy cuts the values off before that point (the debugger shows only 15 significant digits).

Since SymPy coerces my Python Decimals into SymPy Floats, not SymPy Rationals, when the expression is built, is it reasonable to expect that this is the source of the issue? If so, how could I modify my code so that the computation completes instead of getting stuck forever?
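The repeating-decimal intuition can be checked directly by mapping the stored Decimals back to the nearest simple fractions. A minimal sketch using only the standard library, assuming the histogram was built from fewer than 1000 trades (the 1000 passed to limit_denominator encodes that assumption): Fraction(Decimal) converts exactly, and limit_denominator returns the closest fraction with a bounded denominator. The recovered numerator and denominator could then be passed to SymPy's Rational to get exact exponents.

```python
from decimal import Decimal
from fractions import Fraction

# The two non-zero pdf values from the question's histogram.
pdf_values = [
    Decimal('0.454861111111111104943205418749130330979824066162109375'),
    Decimal('0.545138888888888839545643349993042647838592529296875'),
]

# Fraction(Decimal) is exact; limit_denominator(1000) then returns the closest
# fraction whose denominator is at most 1000 (the assumed maximum trade count).
recovered = [Fraction(v).limit_denominator(1000) for v in pdf_values]
print(recovered)  # [Fraction(131, 288), Fraction(157, 288)]
```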


Solution

  • When using SymPy, floating-point numbers cause many headaches, so use exact numbers whenever possible. If you convert the floats in your expression to rational numbers with nsimplify, the solution is computed immediately.

    Change these lines of code:

    expression = Mul(*product_terms)
    # ...
    print("Solutions:", solutions)
    

    to:

    expression = Mul(*product_terms).nsimplify()
    # ...
    print("Solutions:", solutions[0].n())
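A variation that avoids floats from the start, rather than converting them afterwards with nsimplify, is sketched below. It assumes the raw trade counts (131, 0 and 157 out of 288) are still available when the histogram is built; the counts dictionary is a hypothetical stand-in for that data. The exponents are exact SymPy Rationals, and M_factor is parsed from its decimal string with Rational, so nothing passes through binary floating point. Instead of calling solve on the derivative of the product, it solves for the zeros of the logarithmic derivative. That is a swapped-in technique, valid only where every base is positive, and for this expression it reduces to a linear equation in x.

```python
from sympy import Mul, Rational, diff, numer, solve, symbols, together

x = symbols('x')

# Hypothetical input: raw trade counts per bin (131 + 0 + 157 = 288 trades).
counts = {-1: 131, 0: 0, 1: 157}
total = sum(counts.values())

A = Rational(1)
e = Rational(1)
# Rational() parses a decimal string exactly, with no binary rounding.
M_factor = Rational('1.0828567056280801')
fee = Rational('0.0005')

bases, freqs = [], []
for cp, n in counts.items():
    profit = (e * M_factor**cp - e) * x
    commission = (e * M_factor**cp + e) * fee * x
    bases.append((A + profit - commission) / A)
    freqs.append(Rational(n, total))  # exact exponent, e.g. 131/288

# Same product as in the question, but with exact coefficients and exponents.
expression = Mul(*[b**f for b, f in zip(bases, freqs)])

# Where every base is positive, expression > 0, so its critical points are
# the zeros of d/dx log(expression) = sum(f * b'/b). Here that sum is a
# rational function whose numerator is linear in x.
log_derivative = sum(f * diff(b, x) / b for b, f in zip(bases, freqs))
solutions = solve(numer(together(log_derivative)), x)
print("Solutions:", [s.n() for s in solutions])
```

Because the only solve call is on a linear polynomial, it cannot hang the way the float version did; the trade-off is that the positivity assumption on the bases has to hold near the critical point.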