Let's say I have a function foo(x, a, b) and I want to find a specific one of its (potentially multiple) roots x0, i.e. a value x0 such that foo(x0, a, b) == 0. I know that for (a, b) == (0, 0) the root I want is x0 == 0, and that the function changes continuously with a and b, so I can "follow" the root from (0, 0) to the desired (a, b).
Here's an example function.
def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x
For (a, b) == (0, 0) I want the root at 0, for (2, 0) I want the one near 1.5, and for (2, 1) I want the one near 2.2.
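For reference, each of these can be checked individually with a hand-picked bracket (the brackets here are chosen by eye from the plot below):

import numpy as np
from scipy.optimize import brentq

def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x

print(brentq(foo, -1, 1, args=(0, 0)))  # 0.0
print(brentq(foo, 1, 2, args=(2, 0)))   # ~1.5
print(brentq(foo, 2, 3, args=(2, 1)))   # ~2.2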
Now, this problem seems like one that may be common enough to have a prepared, fast solver in scipy or another standard package (or tools to easily and efficiently construct one). However, I don't know what terms to search for to find it (or verify that it doesn't exist). Is there a ready-made tool for this? What is this kind of problem called?
Clarifications:

- For some (a, b), the "correct" root may disappear (e.g. for (1, 3) in the example). When this happens, returning nan is the preferred behavior, though this is not super important.
- A large number of (a, b) should be quickly solved, not just a single one. I will go on calculating the root for a lot of different parameters, e.g. for plotting and integrating over them.

Here's a quickly put together reference implementation that does pretty much what I want for the example above (and creates the plots). It's not very fast, of course, which somewhat limits me in my actual application.
import functools
import numpy as np
from scipy.optimize import root_scalar
from matplotlib import pyplot as plt
def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x
fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("x")
ax.set_ylabel("foo")
x = np.linspace(-np.pi, np.pi, 201)
ax.plot(x, foo(x, 0, 0), label="(a, b) = (0, 0)")
ax.plot(x, foo(x, 2, 0), label="(a, b) = (2, 0)")
ax.plot(x, foo(x, 2, 1), label="(a, b) = (2, 1)")
ax.legend()
plt.show()
# Semi-bodged solver for reference:
def _dfoo(x, a, b):
    return -(1 + a) * np.cos(a + b - x) - 1

def _solve_fooroot(guess, a, b):
    if np.isnan(guess):
        return np.nan
    # Determine limits for finding the root.
    # Allow for slightly larger limits to account for numerical imprecision.
    maxlim = 1.1 * (1 + a)
    y0 = foo(guess, a, b)
    if y0 == 0:
        return guess
    dy0 = _dfoo(guess, a, b)
    estimate = -y0 / dy0
    # Too small estimates are no good.
    if np.abs(estimate) < 1e-2 * maxlim:
        estimate = np.sign(estimate) * 1e-2 * maxlim
    for lim in np.arange(guess, guess + 10 * estimate, 1e-1 * estimate):
        if np.sign(foo(lim, a, b)) != np.sign(y0):
            bracket = sorted([guess, lim])
            break
    else:
        return np.nan
    sol = root_scalar(foo, (a, b), bracket=bracket)
    return sol.root
@functools.cache
def _fooroot(an, astep, bn, bstep):
    if an == 0:
        if bn == 0:
            return 0
        guessan, guessbn = an, bn - 1
    else:
        guessan, guessbn = an - 1, bn
    # Avoid reaching maximum recursion depth.
    for thisbn in range(0, guessbn, 100):
        _fooroot(0, astep, thisbn, bstep)
    for thisan in range(0, guessan, 100):
        _fooroot(thisan, astep, guessbn, bstep)
    guess = _fooroot(guessan, astep, guessbn, bstep)
    return _solve_fooroot(guess, an * astep, bn * bstep)

@np.vectorize
def fooroot(a, b):
    astep = (-1 if a < 0 else 1) * 1e-2
    bstep = (-1 if b < 0 else 1) * 1e-2
    guess = _fooroot(int(a / astep), astep, int(b / bstep), bstep)
    return _solve_fooroot(guess, a, b)
print(fooroot(0, 0))
print(fooroot(2, 0))
print(fooroot(2, 1))
fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("b")
ax.set_ylabel("fooroot(a, b)")
b = np.linspace(-3, 3, 201)
for a in [0, 0.2, 0.5]:
    ax.plot(b, fooroot(a, b), label=f"a = {a}")
ax.legend()
plt.show()
fig, ax = plt.subplots()
ax.grid()
ax.set_xlabel("a")
ax.set_ylabel("b")
a = np.linspace(-1, 1, 201)
b = np.linspace(-3.5, 3.5, 201)
aa, bb = np.meshgrid(a, b)
pcm = ax.pcolormesh(aa, bb, fooroot(aa, bb))
fig.colorbar(pcm, label="fooroot(a, b)")
plt.show()
Get rid of your @cache and @vectorize; neither is likely to help you for the following, and they're just noise. (If they're needed for outer code, you haven't shown that outer code, so the point stands.)
Do keep using Scipy's root-finding iteratively, but beyond that your procedure should look pretty different. I propose:
Get your initial estimate x0, a0, b0. Then, in a loop, step a and b toward their target values and call root_scalar with your new estimate and the new a and b, passing analytic fprime and fprime2, probably using Halley's method, and with relaxed tolerances that are parametric and appropriate to the function.

Of course, that's the ideal case. In practice, Scipy limitations mean that the vectorised methods cannot easily use second-order gradients. That in turn means that Halley is unavailable, but even linear Jacobian steps work well.
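As a minimal sketch of that scalar ideal case (assuming the question's foo; the derivatives dfoo and d2foo and the name follow_root_scalar are mine, and the relaxed tolerances are left at Scipy's defaults for brevity):

import numpy as np
from scipy.optimize import root_scalar

def foo(x, a, b):
    return (1 + a) * np.sin(a + b - x) - x

def dfoo(x, a, b):  # first derivative of foo w.r.t. x
    return -(1 + a) * np.cos(a + b - x) - 1

def d2foo(x, a, b):  # second derivative of foo w.r.t. x
    return -(1 + a) * np.sin(a + b - x)

def follow_root_scalar(a1, b1, x0=0.0, steps=10):
    # Walk (a, b) from (0, 0) to (a1, b1); each converged root seeds
    # the next solve. Halley's method takes x0, fprime and fprime2.
    x = x0
    for t in np.linspace(0, 1, steps)[1:]:
        sol = root_scalar(foo, args=(t * a1, t * b1), x0=x,
                          fprime=dfoo, fprime2=d2foo, method='halley')
        x = sol.root
    return x

print(follow_root_scalar(2, 0))  # ~1.5
print(follow_root_scalar(2, 1))  # ~2.2

With too few steps the iterate can hop to a neighbouring root, so steps has to suit how quickly the function deforms along the path.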
The following demonstration shows a fully-vectorised path traversal, with two independent start points and stop points in a, b space. This can be arbitrarily extended to as many consecutive paths as you want.
import logging
import typing

import numpy as np
from numpy.typing import ArrayLike
from scipy.optimize import root, check_grad


class Trivariate(typing.Protocol):
    def __call__(
        self,
        x: np.ndarray, a: np.ndarray, b: np.ndarray,
    ) -> np.ndarray: ...


def foo(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.sin(a + b - x) - x


def dfoo_dx(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (-1 - a)*np.cos(a + b - x) - 1


def dfoo_da(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.cos(a + b - x) + np.sin(a + b - x)


def dfoo_db(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
    return (1 + a)*np.cos(a + b - x)


def follow_root(
    fun: Trivariate, dfdx: Trivariate, dfda: Trivariate, dfdb: Trivariate,
    a0: ArrayLike, a1: ArrayLike,
    b0: ArrayLike, b1: ArrayLike,
    x0: ArrayLike,
    steps: int = 10,
    method: str = 'hybr',
    follow_tol: float = 1e-2, follow_reltol: float = 1e-2,
    polish_tol: float = 1e-12, polish_reltol: float = 1e-12,
) -> np.ndarray:
    def dfdx_sparse(x: np.ndarray, a: np.ndarray, b: np.ndarray) -> np.ndarray:
        return np.diag(dfdx(x, a, b))

    x_est = np.asarray(x0)
    ab0 = np.array((a0, b0))
    ab1 = np.array((a1, b1))
    # (number of steps, ab, dimensions of a0...) = (n-1, 2, ...)
    ab = np.linspace(start=ab0, stop=ab1, num=steps)
    da = ab[1, 0] - ab[0, 0]
    db = ab[1, 1] - ab[0, 1]

    for i, ((ai, bi), (ai1, bi1)) in enumerate(zip(ab[:-1], ab[1:])):
        dfdxi = dfdx(x_est, ai, bi)
        dxda = dfda(x_est, ai, bi)/dfdxi
        dxdb = dfdb(x_est, ai, bi)/dfdxi
        # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
        step = -dxda*da - dxdb*db
        last = i == len(ab) - 2
        result = root(
            fun=fun, args=(ai1, bi1), jac=dfdx_sparse, x0=x_est + step, method=method,
            tol=polish_tol if last else follow_tol,
            options={
                'xtol': polish_reltol if last else follow_reltol,
                'col_deriv': True,
            },
        )
        if not result.success:
            raise ValueError('Root finding failed: ' + result.message)
        logging.debug('#%d x%d %s +%s ~ %s = %s', i, result.nfev, x_est, step, x_est + step, result.x)
        x_est = result.x

    return x_est


def main() -> None:
    # Don't do this in production!
    logging.getLogger().setLevel(logging.DEBUG)

    err = check_grad(
        lambda x: foo(x, np.array((0, 2)), np.array((0, 0))),
        lambda x: np.diag(dfoo_dx(x, np.array((0, 2)), np.array((0, 0)))),
        (0.1, 1.5),  # x0
    )
    assert err < 1e-7

    err = check_grad(
        lambda a: foo(np.array((0.1, 1.5)), a, np.array((0, 0))),
        lambda a: np.diag(dfoo_da(np.array((0.1, 1.5)), a, np.array((0, 0)))),
        (0, 2),  # a0
    )
    assert err < 1e-7

    err = check_grad(
        lambda b: foo(np.array((0.1, 1.5)), np.array((0, 2)), b),
        lambda b: np.diag(dfoo_db(np.array((0.1, 1.5)), np.array((0, 2)), b)),
        (0, 0),  # b0
    )
    assert err < 1e-7

    follow_root(
        fun=foo, dfdx=dfoo_dx, dfda=dfoo_da, dfdb=dfoo_db,
        a0=(0, 2), a1=(1.7, 2),
        b0=(0, 0), b1=(0, 1),
        x0=(0, 1.5),
    )


if __name__ == '__main__':
    main()
DEBUG:root:#0 x5 [0. 1.5] +[0.09444444 0.08052514] ~ [0.09444444 1.58052514] = [0.1025362 1.56306428]
DEBUG:root:#1 x4 [0.1025362 1.56306428] +[0.10987707 0.07990566] ~ [0.21241326 1.64296994] = [0.21851137 1.64274921]
DEBUG:root:#2 x4 [0.21851137 1.64274921] +[0.12155443 0.07945781] ~ [0.3400658 1.72220702] = [0.34478035 1.72196514]
DEBUG:root:#3 x4 [0.34478035 1.72196514] +[0.13061949 0.0789664 ] ~ [0.47539984 1.80093153] = [0.47912896 1.80066588]
DEBUG:root:#4 x4 [0.47912896 1.80066588] +[0.13781338 0.07842657] ~ [0.61694234 1.87909245] = [0.61994978 1.87880026]
DEBUG:root:#5 x4 [0.61994978 1.87880026] +[0.14363144 0.07783262] ~ [0.76358122 1.95663288] = [0.76604746 1.95631084]
DEBUG:root:#6 x4 [0.76604746 1.95631084] +[0.14841417 0.07717773] ~ [0.91446163 2.03348857] = [0.9165135 2.03313271]
DEBUG:root:#7 x4 [0.9165135 2.03313271] +[0.15240176 0.07645371] ~ [1.06891526 2.10958643] = [1.07064402 2.10919196]
DEBUG:root:#8 x7 [1.07064402 2.10919196] +[0.15576766 0.07565065] ~ [1.22641168 2.1848426 ] = [1.22788398 2.18440359]
This works fine with a reduced number of steps; with only four steps:
DEBUG:root:#0 x5 [0. 1.5] +[0.28333333 0.24157542] ~ [0.28333333 1.74157542] = [0.34477692 1.72196632]
DEBUG:root:#1 x5 [0.34477692 1.72196632] +[0.39185914 0.23689925] ~ [0.73663606 1.95886557] = [0.76604622 1.95631082]
DEBUG:root:#2 x8 [0.76604622 1.95631082] +[0.44524268 0.23153319] ~ [1.2112889 2.18784401] = [1.22788398 2.18440359]
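The reason few steps suffice is the prediction step: by the implicit function theorem, along the followed branch f(x(a, b), a, b) = 0 we have dx/da = -dfda/dfdx and dx/db = -dfdb/dfdx, so each iteration first moves the estimate by -(dfda*da + dfdb*db)/dfdx and only then lets root polish the result.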
To cover a whole 2D grid in (a, b), keep the main idea (and its gradients); build the output row by row:
import typing
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
from numpy.typing import ArrayLike
from scipy.optimize import root, check_grad


class Trivariate(typing.Protocol):
    def __call__(
        self, x: np.ndarray, ab: np.ndarray,
    ) -> np.ndarray: ...


def foo(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    return (1 + a)*np.sin(a + b - x) - x


def dfoo_dx(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    return (-1 - a)*np.cos(a + b - x) - 1


def dfoo_dab(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
    a, b = ab
    abx = a + b - x
    a1cos = (1 + a)*np.cos(abx)
    return np.stack((a1cos + np.sin(abx), a1cos))


def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
) -> np.ndarray:
    dfdxi = dfdx(x0, ab0)
    dxdab = dfdab(x0, ab0)/dfdxi
    # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
    step = (-dab) @ dxdab
    result = baked_root(args=ab1, x0=x0 + step)
    if not result.success:
        raise ValueError('Root finding failed: ' + result.message)
    return result.x


def roots_2d(
    fun: Trivariate, dfdx: Trivariate, dfdab: Trivariate,
    a0: float, a1: float,
    b0: float, b1: float,
    centre_estimate: float,
    resolution: int = 201,
    method: str = 'hybr',
    tol: float = 1e-2, reltol: float = 1e-2,
    dtype: np.dtype = np.float32,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    def dfdx_sparse(x: np.ndarray, ab: np.ndarray) -> np.ndarray:
        return np.diag(dfdx(x, ab))

    baked_root = partial(
        root, fun=fun, jac=dfdx_sparse, method=method, tol=tol,
        options={'xtol': reltol, 'col_deriv': True},
    )
    baked_next = partial(
        next_roots, baked_root=baked_root, dfdx=dfdx, dfdab=dfdab,
    )

    aser = np.linspace(start=a0, stop=a1, num=resolution, dtype=dtype)
    bser = np.linspace(start=b0, stop=b1, num=resolution, dtype=dtype)
    da = aser[1] - aser[0]
    db = bser[1] - bser[0]
    aa, bb = np.meshgrid(aser, bser)
    aabb = np.stack((aa, bb))  # (2, 201, 201): (ab, b index, a index)
    out = np.empty_like(aa)

    # Centre point, the only one for which we rely on an estimate from the caller
    imid = resolution//2
    out[imid, imid] = baked_root(args=aabb[:, imid, imid], x0=centre_estimate).x.squeeze()

    # Centre to right, scalars
    dar = np.array((da, 0), dtype=da.dtype)
    for j in range(imid + 1, resolution):
        out[imid, j] = baked_next(
            dab=dar, ab0=aabb[:, imid, j-1], ab1=aabb[:, imid, j], x0=out[imid, j-1],
        ).squeeze()

    # Centre to left, scalars
    dal = -dar
    for j in range(imid - 1, -1, -1):
        out[imid, j] = baked_next(
            dab=dal, ab0=aabb[:, imid, j+1], ab1=aabb[:, imid, j], x0=out[imid, j+1],
        ).squeeze()

    # Down rows
    dbd = np.array((0, db), dtype=db.dtype)
    for i in range(imid + 1, resolution):
        out[i] = baked_next(
            dab=dbd, ab0=aabb[:, i-1], ab1=aabb[:, i], x0=out[i-1],
        )

    # Up rows
    dbu = -dbd
    for i in range(imid - 1, -1, -1):
        out[i] = baked_next(
            dab=dbu, ab0=aabb[:, i+1], ab1=aabb[:, i], x0=out[i+1],
        )

    return aa, bb, out


def plot(aa: np.ndarray, bb: np.ndarray, x: np.ndarray) -> plt.Figure:
    fig, ax = plt.subplots()
    ax.grid()
    ax.set_xlabel('a')
    ax.set_ylabel('b')
    mesh = ax.pcolormesh(aa, bb, x, vmin=-3, vmax=3)
    fig.colorbar(mesh, label='root')
    return fig


def main() -> None:
    x0 = np.array((0.1, 1.5))
    ab0 = np.array([(0.3, 2), (0.1, 0.2)])
    err = check_grad(
        partial(foo, ab=ab0),
        lambda x: np.diag(dfoo_dx(x, ab0)),
        x0,
    )
    assert err < 1e-7

    # err = check_grad(
    #     partial(foo, x0),
    #     lambda ab: dfoo_dab(x0, ab),
    #     ab0,
    # )
    # assert err < 1e-7

    aa, bb, x = roots_2d(
        fun=foo, dfdx=dfoo_dx, dfdab=dfoo_dab,
        a0=-1, a1=1,
        b0=-3, b1=3,
        centre_estimate=0,
    )
    plot(aa, bb, x)
    plt.show()


if __name__ == '__main__':
    main()
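A note on the design: solving the centre point first and then sweeping outward (along the centre row, then row by row up and down) means each solve starts from a converged neighbour just one grid cell away, which keeps the linear prediction accurate and the solver on the same root branch.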
This executes in a second or two, producing a pcolormesh of the followed root over the whole (a, b) grid.

To take care of disappearing roots, there are no perfect solutions. Either you need to write a flood-fill algorithm, which is complicated, or you can apply a simple heuristic like the following, which marks as nan any solution that lands too far from its linear prediction:
def next_roots(
    baked_root, dfdx: Trivariate, dfdab: Trivariate,
    dab: np.ndarray,
    ab0: ArrayLike, ab1: ArrayLike, x0: ArrayLike,
    error_bound: float = 1e-2,
) -> np.ndarray:
    input_mask = np.isfinite(x0)
    x0_masked = x0[input_mask]
    dfdxi = dfdx(x0_masked, ab0[:, input_mask])
    dxdab = dfdab(x0_masked, ab0[:, input_mask])/dfdxi
    # If a and b are perturbed, where will x go? This is linear, but it can be extended to second-order.
    step = (-dab) @ dxdab
    xest = x0_masked + step
    result = baked_root(args=ab1[:, input_mask], x0=xest)
    if not result.success:
        raise ValueError('Root finding failed: ' + result.message)
    xsol = result.x
    est_error = np.abs(xsol - xest)
    xsol[est_error > error_bound] = np.nan
    xnew = np.full_like(x0, fill_value=np.nan)
    xnew[input_mask] = xsol
    return xnew
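Two caveats with this heuristic: it assumes vector-valued x0 and ab0 as in the row-by-row sweeps, so the scalar centre-row calls in roots_2d would need their inputs promoted via np.atleast_1d (or similar) first; and because input_mask excludes every nan from later solves, a point that goes nan stays nan for the rest of the sweep even if the root reappears. Recovering such points is what the flood fill would buy you.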