Tags: python, scipy, numerical-methods, scipy-optimize

Find correct root of parametrized function given solution for one set of parameters


Let's say I have a function foo(x, a, b) and I want to find a specific one of its (potentially multiple) roots x0, i.e. a value x0 such that foo(x0, a, b) == 0. I know that for (a, b) == (0, 0) the root I want is x0 == 0 and that the function changes continuously with a and b, so I can "follow" the root from (0, 0) to the desired (a, b).

Here's an example function.

def foo(x, a, b):
    """Example function of ``x`` with parameters ``a`` and ``b`` whose root is sought."""
    phase = a + b - x
    amplitude = 1 + a
    return amplitude * np.sin(phase) - x

example function plot

For (a, b) == (0, 0) I want the root at 0, for (2, 0) I want the one near 1.5 and for (2, 1) I want the one near 2.2.

Now, this problem seems like one that may be common enough to have a prepared, fast solver in scipy or another standard package (or tools to easily and efficiently construct one). However, I don't know what terms to search for to find it (or verify that it doesn't exist). Is there a ready-made tool for this? What is this kind of problem called?


Clarifications:


Here's a quickly put together reference implementation that does pretty much what I want for the example above (and creates the plot). It's not very fast, of course, which somewhat limits me in my actual application.

import functools
import numpy as np
from scipy.optimize import root_scalar
from matplotlib import pyplot as plt

def foo(x, a, b):
    """Return ``(1 + a) * sin(a + b - x) - x``; works element-wise on arrays."""
    angle = a + b - x
    scaled_sine = (1 + a) * np.sin(angle)
    return scaled_sine - x

# Plot foo(x) on [-pi, pi] for the three parameter pairs discussed in the text.
fig, ax = plt.subplots()
ax.set(xlabel="x", ylabel="foo")
ax.grid()
x = np.linspace(-np.pi, np.pi, num=201)
ax.plot(x, foo(x, a=0, b=0), label="(a, b) = (0, 0)")
ax.plot(x, foo(x, a=2, b=0), label="(a, b) = (2, 0)")
ax.plot(x, foo(x, a=2, b=1), label="(a, b) = (2, 1)")
ax.legend()
plt.show()

# Semi-bodged solver for reference:

def _dfoo(x, a, b):
    return -(1 + a) * np.cos(a + b - x) - 1

def _solve_fooroot(guess, a, b):
    """Refine ``guess`` into a root of ``foo(x, a, b)`` by bracketing.

    A Newton step taken from ``guess`` supplies the direction and length
    scale in which to scan for a sign change of ``foo``; ``root_scalar`` then
    solves inside the first bracket found.

    Returns the root as a float. Returns ``nan`` when ``guess`` is ``nan``,
    when no finite Newton step exists (zero derivative or non-finite function
    value), or when no sign change is found within ten step lengths.
    """
    if np.isnan(guess):
        return np.nan
    # Determine limits for finding the root.
    # Allow for slightly larger limits to account for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        # float() so that an integer guess is never handed back as an integer.
        return float(guess)
    dy0 = _dfoo(guess, a, b)
    if dy0 == 0:
        # Flat tangent: the Newton step would be infinite and np.arange below
        # would raise, so report "no root" the same way as other failures.
        return np.nan
    estimate = -y0 / dy0
    if not np.isfinite(estimate):
        return np.nan
    # Too small estimates are no good.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.sign(estimate) * 1e-2 * maxlim
    # Scan in tenths of the step, up to ten steps away, for a sign change.
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        # The scan finished without ever seeing a sign change.
        return np.nan
    sol = root_scalar(foo, (a, b), bracket=bracket)
    return sol.root

@functools.cache
def _fooroot(an, astep, bn, bstep):
    if an == 0:
        if bn == 0:
            return 0
        guessan, guessbn = an, bn - 1
    else:
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)

# otypes is fixed to float: without it np.vectorize infers the output dtype
# from the first element, so an integer first result would truncate all later
# roots, and empty inputs would raise.
@functools.partial(np.vectorize, otypes=[float])
def fooroot(a, b):
    """Followed root of ``foo(x, a, b)`` for scalar or array ``a`` and ``b``.

    The nearest grid point towards the origin (grid spacing 1e-2 in both
    parameters) supplies the initial guess; a final solve moves from that
    grid point to the exact ``(a, b)``. Returns ``nan`` where the root could
    not be followed.
    """
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)

# Spot-check the three roots named in the question.
print(fooroot(0, 0))
print(fooroot(2, 0))
print(fooroot(2, 1))

# Followed root as a function of b for a few fixed values of a.
fig, ax = plt.subplots()
ax.set(xlabel="b", ylabel="fooroot(a, b)")
ax.grid()
b = np.linspace(-3, 3, num=201)
for a in (0, 0.2, 0.5):
    ax.plot(b, fooroot(a, b), label=f"a = {a}")
ax.legend()
plt.show()

# Followed root over the (a, b) plane as a colour mesh.
fig, ax = plt.subplots()
ax.set(xlabel="a", ylabel="b")
ax.grid()
a = np.linspace(-1, 1, num=201)
b = np.linspace(-3.5, 3.5, num=201)
aa, bb = np.meshgrid(a, b)
pcm = ax.pcolormesh(aa, bb, fooroot(aa, bb))
fig.colorbar(pcm, label="fooroot(a, b)")
plt.show()

reference solution plot

reference solution colormesh


Solution

  • Get rid of your @cache and @vectorize; neither is likely to help you for the following and they're just noise. (If they're needed for outer code, you haven't shown that outer code, so the point stands.)

    Do keep using Scipy's root-finding iteratively, but beyond that your procedure should look pretty different. I propose:

    Get your initial estimate x0, a0, b0. Then in a loop:

    1. Infer by analytic integration the paraboloid intersecting the current point whose first and second derivatives with respect to a and b match those of f.
    2. Increment a and b at their fixed step size. If the function is smooth and parabolic step estimation works well then this step may be somewhat large; but if you're writing this for a generic routine that can take any function then it must be parametric.
    3. Call root_scalar with your new estimate, new a and b, passing analytic fprime and fprime2, probably using Halley's Method, and having relaxed tolerances that are parametric and appropriate to the function.

    Of course that's the ideal case, but in practice Scipy limitations mean that the vectorised methods cannot easily use second-order gradients. That in turn means that Halley is unavailable, but even linear Jacobian steps work well.

    The following demonstration shows a fully-vectorised path traversal, with two independent start points and stop points in a, b space. This can be arbitrarily extended to as many consecutive paths as you want.

    import logging
    import typing
    
    import numpy as np
    from numpy._typing import ArrayLike
    from scipy.optimize import root, check_grad
    
    
    class Trivariate(typing.Protocol):
        """Structural type for a callable taking arrays ``(x, a, b)`` and returning an array."""
        def __call__(
            self,
            x: np.ndarray, a: np.ndarray, b: np.ndarray,
        ) -> np.ndarray: ...
    
    
    def foo(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Function whose root in ``x`` is followed as ``a`` and ``b`` change."""
        phase = a + b - x
        return (1 + a)*np.sin(phase) - x
    
    
    def dfoo_dx(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Partial derivative of ``foo`` with respect to ``x``."""
        phase = a + b - x
        return (-1 - a)*np.cos(phase) - 1
    
    
    def dfoo_da(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Partial derivative of ``foo`` with respect to ``a``."""
        phase = a + b - x
        return (1 + a)*np.cos(phase) + np.sin(phase)
    
    
    def dfoo_db(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        """Partial derivative of ``foo`` with respect to ``b``."""
        phase = a + b - x
        return (1 + a)*np.cos(phase)
    
    
    def follow_root(
        fun: Trivariate, dfdx: Trivariate, dfda: Trivariate, dfdb: Trivariate,
        a0: ArrayLike, a1: ArrayLike,
        b0: ArrayLike, b1: ArrayLike,
        x0: ArrayLike,
        steps: int = 10,
        method: str = 'hybr',
        follow_tol: float = 1e-2, follow_reltol: float = 1e-2,
        polish_tol: float = 1e-12, polish_reltol: float = 1e-12,
    ) -> np.ndarray:
        """Follow roots of ``fun(x, a, b)`` along straight paths in (a, b) space.

        Each path runs from ``(a0, b0)`` to ``(a1, b1)`` in ``steps`` evenly
        spaced points. At every point a linear predictor derived from the
        partial derivatives estimates the new root, and ``scipy.optimize.root``
        corrects it. Intermediate points use the loose ``follow_*`` tolerances;
        only the final point is solved to the ``polish_*`` tolerances.

        :param fun: function whose root in ``x`` is followed
        :param dfdx: partial derivative of ``fun`` with respect to ``x``
        :param dfda: partial derivative of ``fun`` with respect to ``a``
        :param dfdb: partial derivative of ``fun`` with respect to ``b``
        :param a0: start value(s) of ``a``; scalars or equally shaped 1-D sequences
        :param a1: end value(s) of ``a``
        :param b0: start value(s) of ``b``
        :param b1: end value(s) of ``b``
        :param x0: root estimate(s) at ``(a0, b0)``
        :param steps: number of points on the path including both ends (>= 2)
        :param method: solver name passed to ``scipy.optimize.root``; the
            options used here ('xtol', 'col_deriv') are those of 'hybr'
        :return: 1-D array of roots at ``(a1, b1)``
        :raises ValueError: if ``steps < 2`` or any corrector solve fails
        """
        if steps < 2:
            # Without at least a start and an end point there is no increment
            # to take; fail clearly instead of with an IndexError below.
            raise ValueError('steps must be at least 2 (start and end point)')

        def jac_diag(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
            # The problems are independent, so the Jacobian is diagonal.
            return np.diag(dfdx(x, a, b))

        # atleast_1d lets scalar inputs through: np.diag rejects 0-d arrays.
        x_est = np.atleast_1d(np.asarray(x0))
        ab0 = np.array((a0, b0))
        ab1 = np.array((a1, b1))

        # (number of path points, ab, dimensions of a0...) = (steps, 2, ...)
        ab = np.linspace(start=ab0, stop=ab1, num=steps)
        da = ab[1, 0] - ab[0, 0]
        db = ab[1, 1] - ab[0, 1]

        for i, ((ai, bi), (ai1, bi1)) in enumerate(zip(ab[:-1], ab[1:])):
            dfdxi = dfdx(x_est, ai, bi)
            dxda = dfda(x_est, ai, bi)/dfdxi
            dxdb = dfdb(x_est, ai, bi)/dfdxi
            # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
            step = -dxda*da - dxdb*db
            last = i == len(ab) - 2

            result = root(
                fun=fun, args=(ai1, bi1), jac=jac_diag, x0=x_est + step, method=method,
                tol=polish_tol if last else follow_tol,
                options={
                    'xtol': polish_reltol if last else follow_reltol,
                    'col_deriv': True,
                },
            )
            if not result.success:
                raise ValueError('Root finding failed: ' + result.message)

            logging.debug('#%d x%d %s +%s ~ %s = %s', i, result.nfev, x_est, step, x_est + step, result.x)
            x_est = result.x

        return x_est
    
    
    def main() -> None:
        """Check the analytic derivatives numerically, then follow two roots.

        Turns on DEBUG logging for the root logger, compares ``dfoo_dx``,
        ``dfoo_da`` and ``dfoo_db`` against finite differences via
        ``check_grad``, and runs ``follow_root`` on two independent paths.
        """
        # Don't do this in production!
        logging.getLogger().setLevel(logging.DEBUG)

        x_ref = np.array((0.1, 1.5))
        a_ref = np.array((0, 2))
        b_ref = np.array((0, 0))

        # Each entry: (function of one variable, its Jacobian, evaluation point)
        gradient_checks = (
            (
                lambda x: foo(x, a_ref, b_ref),
                lambda x: np.diag(dfoo_dx(x, a_ref, b_ref)),
                (0.1, 1.5),  # x0
            ),
            (
                lambda a: foo(x_ref, a, b_ref),
                lambda a: np.diag(dfoo_da(x_ref, a, b_ref)),
                (0, 2),  # a0
            ),
            (
                lambda b: foo(x_ref, a_ref, b),
                lambda b: np.diag(dfoo_db(x_ref, a_ref, b)),
                (0, 0),  # b0
            ),
        )
        for func, grad, point in gradient_checks:
            err = check_grad(func, grad, point)
            assert err < 1e-7

        follow_root(
            fun=foo, dfdx=dfoo_dx, dfda=dfoo_da, dfdb=dfoo_db,
            a0=(0, 2), a1=(1.7, 2),
            b0=(0, 0), b1=(0, 1),
            x0=(0, 1.5),
        )


    if __name__ == '__main__':
        main()
    
    DEBUG:root:#0 x5 [0.  1.5] +[0.09444444 0.08052514] ~ [0.09444444 1.58052514] = [0.1025362  1.56306428]
    DEBUG:root:#1 x4 [0.1025362  1.56306428] +[0.10987707 0.07990566] ~ [0.21241326 1.64296994] = [0.21851137 1.64274921]
    DEBUG:root:#2 x4 [0.21851137 1.64274921] +[0.12155443 0.07945781] ~ [0.3400658  1.72220702] = [0.34478035 1.72196514]
    DEBUG:root:#3 x4 [0.34478035 1.72196514] +[0.13061949 0.0789664 ] ~ [0.47539984 1.80093153] = [0.47912896 1.80066588]
    DEBUG:root:#4 x4 [0.47912896 1.80066588] +[0.13781338 0.07842657] ~ [0.61694234 1.87909245] = [0.61994978 1.87880026]
    DEBUG:root:#5 x4 [0.61994978 1.87880026] +[0.14363144 0.07783262] ~ [0.76358122 1.95663288] = [0.76604746 1.95631084]
    DEBUG:root:#6 x4 [0.76604746 1.95631084] +[0.14841417 0.07717773] ~ [0.91446163 2.03348857] = [0.9165135  2.03313271]
    DEBUG:root:#7 x4 [0.9165135  2.03313271] +[0.15240176 0.07645371] ~ [1.06891526 2.10958643] = [1.07064402 2.10919196]
    DEBUG:root:#8 x7 [1.07064402 2.10919196] +[0.15576766 0.07565065] ~ [1.22641168 2.1848426 ] = [1.22788398 2.18440359]
    

    This works fine with a reduced number of steps; with only four steps:

    DEBUG:root:#0 x5 [0.  1.5] +[0.28333333 0.24157542] ~ [0.28333333 1.74157542] = [0.34477692 1.72196632]
    DEBUG:root:#1 x5 [0.34477692 1.72196632] +[0.39185914 0.23689925] ~ [0.73663606 1.95886557] = [0.76604622 1.95631082]
    DEBUG:root:#2 x8 [0.76604622 1.95631082] +[0.44524268 0.23153319] ~ [1.2112889  2.18784401] = [1.22788398 2.18440359]
    

    Grid following

    Keep the main idea (and its gradients); build a row-wise output:

    import typing
    from functools import partial
    
    import matplotlib.pyplot as plt
    import numpy as np
    from numpy._typing import ArrayLike
    from scipy.optimize import root, check_grad
    
    
    class Trivariate(typing.Protocol):
        """Structural type for a callable taking ``(x, ab)`` and returning an array.

        Despite the name, the call takes two arguments; the functions in this
        snippet unpack ``ab`` into ``a`` and ``b`` along its first axis.
        """
        def __call__(
            self, x: np.ndarray, ab: np.ndarray,
        ) -> np.ndarray: ...
    
    
    def foo(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        """Function whose root in ``x`` is followed; ``ab`` stacks ``a`` and ``b``."""
        a, b = ab
        phase = a + b - x
        return (1 + a)*np.sin(phase) - x
    
    
    def dfoo_dx(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        """Partial derivative of ``foo`` with respect to ``x``."""
        a, b = ab
        phase = a + b - x
        return (-1 - a)*np.cos(phase) - 1
    
    
    def dfoo_dab(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        """Partial derivatives of ``foo`` with respect to ``a`` and ``b``, stacked on axis 0."""
        a, b = ab
        phase = a + b - x
        scaled_cos = (1 + a)*np.cos(phase)
        d_da = scaled_cos + np.sin(phase)
        d_db = scaled_cos
        return np.stack((d_da, d_db))
    
    
    def next_roots(
        baked_root, dfdx: Trivariate, dfdab: Trivariate,
        dab: np.ndarray,
        ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
    ) -> np.ndarray:
        """Advance known roots ``x0`` at ``ab0`` to the neighbouring point ``ab1``.

        A first-order predictor built from the partial derivatives estimates
        the new root, and ``baked_root`` (a pre-configured solver accepting
        ``args`` and ``x0`` keywords) corrects it.

        :param dab: the (a, b) increment from ``ab0`` to ``ab1``
        :return: the solver's solution vector
        :raises ValueError: if the solver reports failure
        """
        slope = dfdx(x0, ab0)
        sensitivity = dfdab(x0, ab0)/slope
        # If a and b are perturbed, where will x go? Linear prediction only.
        predicted = x0 + (-dab) @ sensitivity

        solution = baked_root(args=ab1, x0=predicted)
        if solution.success:
            return solution.x
        raise ValueError('Root finding failed: ' + solution.message)
    
    
    def roots_2d(
        fun: Trivariate, dfdx: Trivariate, dfdab: Trivariate,
        a0: float, a1: float,
        b0: float, b1: float,
        centre_estimate: float,
        resolution: int = 201,
        method: str = 'hybr',
        tol: float = 1e-2, reltol: float = 1e-2,
        dtype: np.dtype = np.float32,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Follow one root of ``fun(x, ab)`` over a regular (a, b) grid.

        The root is solved at the grid centre from ``centre_estimate``, followed
        point by point along the centre row in both directions, and then
        followed a whole row at a time downwards and upwards from that row.

        :param fun: function whose root in ``x`` is followed
        :param dfdx: partial derivative of ``fun`` with respect to ``x``
        :param dfdab: partial derivatives with respect to ``a`` and ``b``,
            stacked on the first axis
        :param a0: first grid value of ``a``
        :param a1: last grid value of ``a``
        :param b0: first grid value of ``b``
        :param b1: last grid value of ``b``
        :param centre_estimate: root estimate at grid index
            ``(resolution//2, resolution//2)``
        :param resolution: number of grid points per axis (>= 2)
        :param method: solver name passed to ``scipy.optimize.root``; the
            options used here ('xtol', 'col_deriv') are those of 'hybr'
        :return: ``(aa, bb, roots)``, each of shape ``(resolution, resolution)``
            indexed ``[b index, a index]``
        :raises ValueError: if ``resolution < 2`` or any solve fails
        """
        if resolution < 2:
            # The grid spacing below needs two points per axis.
            raise ValueError('resolution must be at least 2')

        def dfdx_diag(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
            # Grid points are independent, so the Jacobian is diagonal.
            return np.diag(dfdx(x, ab))

        baked_root = partial(
            root, fun=fun, jac=dfdx_diag, method=method, tol=tol,
            options={'xtol': reltol, 'col_deriv': True},
        )
        baked_next = partial(
            next_roots, baked_root=baked_root, dfdx=dfdx, dfdab=dfdab,
        )

        aser = np.linspace(start=a0, stop=a1, num=resolution, dtype=dtype)
        bser = np.linspace(start=b0, stop=b1, num=resolution, dtype=dtype)
        da = aser[1] - aser[0]
        db = bser[1] - bser[0]
        aa, bb = np.meshgrid(aser, bser)
        aabb = np.stack((aa, bb))  # (2, resolution, resolution): (ab, b index, a index)
        out = np.empty_like(aa)

        # Centre point, the only one for which we rely on an estimate from the caller
        imid = resolution//2
        centre = baked_root(args=aabb[:, imid, imid], x0=centre_estimate)
        if not centre.success:
            # Every other grid value is seeded from this one, so a bad centre
            # must not be propagated silently.
            raise ValueError('Root finding failed: ' + centre.message)
        out[imid, imid] = centre.x.squeeze()

        # Centre to right, scalars
        dar = np.array((da, 0), dtype=da.dtype)
        for j in range(imid + 1, resolution):
            out[imid, j] = baked_next(
                dab=dar, ab0=aabb[:, imid, j-1], ab1=aabb[:, imid, j], x0=out[imid, j-1],
            ).squeeze()

        # Centre to left, scalars
        dal = -dar
        for j in range(imid - 1, -1, -1):
            out[imid, j] = baked_next(
                dab=dal, ab0=aabb[:, imid, j+1], ab1=aabb[:, imid, j], x0=out[imid, j+1],
            ).squeeze()

        # Down rows
        dbd = np.array((0, db), dtype=db.dtype)
        for i in range(imid + 1, resolution):
            out[i] = baked_next(
                dab=dbd, ab0=aabb[:, i-1], ab1=aabb[:, i], x0=out[i-1],
            )

        # Up rows
        dbu = -dbd
        for i in range(imid - 1, -1, -1):
            out[i] = baked_next(
                dab=dbu, ab0=aabb[:, i+1], ab1=aabb[:, i], x0=out[i+1],
            )

        return aa, bb, out
    
    
    def plot(aa: np.ndarray, bb: np.ndarray, x: np.ndarray) -> plt.Figure:
        """Draw ``x`` over the ``(aa, bb)`` mesh as a colour map and return the figure.

        The colour scale is fixed to the range [-3, 3].
        """
        fig, ax = plt.subplots()
        ax.set(xlabel='a', ylabel='b')
        ax.grid()
        mesh = ax.pcolormesh(aa, bb, x, vmin=-3, vmax=3)
        fig.colorbar(mesh, label='root')
        return fig
    
    
    def main() -> None:
        """Check ``dfoo_dx`` numerically, then solve and plot the root over the grid."""
        x_probe = np.array((0.1, 1.5))
        ab_probe = np.array([(0.3, 2), (0.1, 0.2)])
        gradient_error = check_grad(
            partial(foo, ab=ab_probe),
            lambda x: np.diag(dfoo_dx(x, ab_probe)),
            x_probe,
        )
        assert gradient_error < 1e-7

        # NOTE(review): the matching check for dfoo_dab is disabled in the
        # original; kept here verbatim for reference.
        # err = check_grad(
        #     partial(foo, x0),
        #     lambda ab: dfoo_dab(x0, ab),
        #     ab0,
        # )
        # assert err < 1e-7

        grid_a, grid_b, roots = roots_2d(
            fun=foo, dfdx=dfoo_dx, dfdab=dfoo_dab,
            a0=-1, a1=1,
            b0=-3, b1=3,
            centre_estimate=0,
        )
        plot(grid_a, grid_b, roots)
        plt.show()


    if __name__ == '__main__':
        main()
    

    Executes in a second or two:

    root output

    To take care of disappearing roots, there are no perfect solutions. Either you need to write a flood fill algorithm, which is complicated; or you can just do a simple heuristic like

    def next_roots(
        baked_root, dfdx: Trivariate, dfdab: Trivariate,
        dab: np.ndarray,
        ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
        error_bound: float = 1e-2,
    ) -> np.ndarray:
        """Advance roots ``x0`` at ``ab0`` to ``ab1``, blanking roots that vanish.

        Only finite entries of ``x0`` are followed. A followed root is replaced
        by NaN when the solver lands further than ``error_bound`` from the
        linear prediction, which is taken as a sign that the root being
        tracked has disappeared. Entries that were already NaN stay NaN.

        :param baked_root: pre-configured solver accepting ``args`` and ``x0``
        :param dab: the (a, b) increment from ``ab0`` to ``ab1``
        :param error_bound: largest accepted |solution - prediction|
        :return: array shaped like ``x0`` holding the new roots or NaN
        :raises ValueError: if the solver reports failure
        """
        input_mask = np.isfinite(x0)
        xnew = np.full_like(x0, fill_value=np.nan)
        if not np.any(input_mask):
            # Every incoming root has already vanished: there is nothing to
            # solve, and calling the solver on an empty vector would fail.
            return xnew

        x0_masked = x0[input_mask]
        dfdxi = dfdx(x0_masked, ab0[:, input_mask])
        dxdab = dfdab(x0_masked, ab0[:, input_mask])/dfdxi
        # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
        step = (-dab) @ dxdab
        xest = x0_masked + step

        result = baked_root(args=ab1[:, input_mask], x0=xest)
        if not result.success:
            raise ValueError('Root finding failed: ' + result.message)

        xsol = result.x
        est_error = np.abs(xsol - xest)
        # A large correction means the solver jumped to a different root.
        xsol[est_error > error_bound] = np.nan

        xnew[input_mask] = xsol
        return xnew
    

    blanked missing roots