import numpy as np
from scipy.optimize import minimize
np.random.seed(42)
nDim = 24
xBase = np.random.normal(0, 1, nDim)
x0 = np.zeros(nDim)
loss = lambda x: np.linalg.norm(x - xBase)
# loss = lambda x: (x - xBase).dot(x - xBase)
res = minimize(loss, x0, method = 'BFGS', options={'gtol': 1.0E-3, 'maxiter': 1000})
Here, I minimize the most basic loss function imaginable in 24 dimensions: the distance from x to a fixed point xBase. When I use the squared quadratic loss (commented out), the method converges with no problem. When I take the square root of the objective function, the method fails, even though it is essentially the same problem.
message: Desired error not necessarily achieved due to precision loss.
success: False
status: 2
fun: 1.8180001548836974e-07
x: [ 4.967e-01 -1.383e-01 ... 6.753e-02 -1.425e+00]
nit: 4
jac: [-7.681e-02 4.436e-02 ... 2.302e-01 2.046e-01]
hess_inv: [[ 9.825e-01 1.306e-03 ... 1.217e-02 3.625e-02]
[ 1.306e-03 9.980e-01 ... -1.150e-02 -4.421e-03]
...
[ 1.217e-02 -1.150e-02 ... 6.066e-01 7.074e-02]
[ 3.625e-02 -4.421e-03 ... 7.074e-02 8.894e-01]]
nfev: 1612
njev: 64
The documentation for BFGS suggests modifying the gtol argument, which I have tried. However, it has no effect unless I set it to some ridiculously high number like 1.0E+1, in which case the answer is just wrong. For every gtol in the range 1.0E-1 to 1.0E-10, the exact same message as above is returned. The answer is actually correct, since the objective function value is about 1.0E-7, but the message still says that the optimization failed.
Am I doing something wrong, or is this a bug?
I have had a look at this related question, but my objective function is significantly simpler than in that case, so I suspect that the concerns mentioned in the answers should not apply here.
NOTE: My goal is not to get SciPy to work for this toy problem. My goal is to learn how to manipulate the arguments of minimize to increase the likelihood that it solves my actual problem, which is far more complicated.
There are four ways to fix the problem that you're having in this code.
The function np.linalg.norm(x) doesn't have a smooth first derivative. Here is a plot of it for one variable, from -1 to 1.
import matplotlib.pyplot as plt
import numpy as np

# In one dimension the L2 norm is just |x|, which has a kink at x = 0.
x = np.linspace(-1, 1, 1000)
y = [np.linalg.norm(i) for i in x]
plt.plot(x, y)
plt.xlabel('x')
plt.ylabel('sqrt(x**2)')
plt.show()
The same issue happens in a multi-variable context. (If this doesn't make sense to you, imagine the solver has worked out perfect answers to all but one of the variables. The remaining variable reduces to the above case.)
This is a problem, because when minimize() differentiates this function, the derivative only ever takes the values -1 and 1: its magnitude stays at 1 no matter how close x gets to the minimum. This means the gradient-based gtol criterion is never satisfied, so it never terminates the optimization.
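To make this concrete, here is a quick numerical check (my own sketch, reusing the question's setup) showing that the finite-difference gradient of the norm keeps magnitude roughly 1 no matter how close the current point gets to xBase, so it can never fall below gtol:

import numpy as np
from scipy.optimize import approx_fprime

np.random.seed(42)
nDim = 24
xBase = np.random.normal(0, 1, nDim)
loss = lambda x: np.linalg.norm(x - xBase)

# Evaluate the finite-difference gradient at points ever closer to the minimum.
for dist in (1e-1, 1e-4, 1e-7):
    x = xBase + dist / np.sqrt(nDim)   # a point at distance `dist` from xBase
    g = approx_fprime(x, loss, 1e-9)
    print(f"distance {dist:.0e}: |grad| = {np.linalg.norm(g):.3f}")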
An alternative to gtol is xrtol, the X Relative TOLerance termination condition.
res = minimize(loss, x0, method = 'BFGS', options={'gtol': 1.0E-3, 'maxiter': 1000, 'xrtol': 1e-9})
This terminates optimization if the relative change in x from the next step would be smaller than one part in one billion. This also allows optimization to succeed.
To make it easier to fit data, most optimizers that fit data iteratively use the square of the L2 norm, not the L2 norm itself. For example, scikit-learn's SGDRegressor uses squared_error as its default loss. This produces the same fit, but has nicer derivatives.
In other words, I would recommend lambda x: (x - xBase).dot(x - xBase) over np.linalg.norm(x - xBase).
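For this toy problem, that just means running the question's commented-out variant. With the squared loss the gradient is 2*(x - xBase), which shrinks to zero at the minimum, so gtol behaves as expected (a minimal sketch of that run):

import numpy as np
from scipy.optimize import minimize

np.random.seed(42)
nDim = 24
xBase = np.random.normal(0, 1, nDim)
x0 = np.zeros(nDim)

# Squared L2 loss: smooth everywhere, with a gradient that vanishes at the minimum.
loss_sq = lambda x: (x - xBase).dot(x - xBase)

res = minimize(loss_sq, x0, method='BFGS',
               options={'gtol': 1.0E-3, 'maxiter': 1000})
print(res.success, res.fun)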
If you don't supply a gradient, SciPy approximates the function's derivative by numerical differentiation. In this case that approximation is not very accurate, and it can be improved by supplying an analytic derivative. This produces a better fit, faster.
import numpy as np
from scipy.optimize import minimize

np.random.seed(42)
nDim = 24
xBase = np.random.normal(0, 1, nDim)
x0 = np.zeros(nDim)

def loss(x):
    return np.linalg.norm(x - xBase)

def loss_jac(x):
    # Gradient of the *squared* loss -- see the note below.
    return 2 * (x - xBase)

res = minimize(loss, x0,
               jac=loss_jac,
               method='BFGS',
               options={
                   'gtol': 1.0E-3,
                   'maxiter': 1000,
               })
Note: technically, this is the wrong derivative; it is the gradient of the square of the loss, not of the loss itself. The exact gradient of the norm is (x - xBase) / np.linalg.norm(x - xBase), a unit vector, but I found this didn't work (consistent with the gtol problem described above, since the magnitude of that gradient never shrinks).
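If you do want to pass the exact gradient of the norm, a possible sketch (my own; norm_jac is a name I made up, and for the reasons above I would pair it with the step-based xrtol stop rather than rely on gtol) guards the division at the minimum:

def norm_jac(x):
    d = x - xBase
    n = np.linalg.norm(d)
    # The gradient of the norm is the unit vector d / ||d||; it is
    # undefined exactly at the minimum, so guard against n == 0.
    return d / n if n > 0 else np.zeros_like(d)

res = minimize(loss, x0,
               jac=norm_jac,
               method='BFGS',
               # Stop on step size, since |grad| stays at 1 near the minimum.
               options={'maxiter': 1000, 'xrtol': 1e-9})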
A final alternative would be to use a solver that tolerates non-differentiable functions, like COBYQA or Powell.
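For instance, Powell's method is derivative-free, so the kink at x == xBase doesn't upset its termination test (a minimal sketch reusing loss and x0 from above, with default tolerances):

# Powell needs no gradient at all, so the non-smooth minimum is not a problem.
res = minimize(loss, x0, method='Powell',
               options={'maxiter': 1000})
print(res.success, res.fun)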