Tags: python, non-linear-regression, convergence, jax

Parameters do not converge at a lower tolerance in a nonlinear least squares implementation in Python


I am translating some of my R code to Python as a learning exercise, in particular to try JAX for autodiff.

In my functions implementing nonlinear least squares, when I set the tolerance to 1e-8, the estimated parameters are nearly identical after several iterations, but the algorithm never appears to converge.

However, the R code converges at the 12th iteration with tol=1e-8 and at the 14th iteration with tol=1e-9. The estimated parameters are almost the same as the ones from the Python implementation.

I suspect this has something to do with floating point, but I am not sure which step I could improve to make the Python version converge as quickly as the R one.

Here is my code; most steps are the same as in R:

import jax
import jax.numpy as jnp
import numpy as np
import scipy.linalg as ola


def update_parm(X, y, fun, dfun, parm, theta, wt):
    len_y = len(y)
    mean_fun = fun(X, parm)

    # Weights: computed from the variance function mu**theta if wt is True,
    # no weighting if wt is False, otherwise taken as user-supplied weights
    if isinstance(wt, bool):
        if wt:
            var_fun = np.exp(theta * np.log(mean_fun))  # mean_fun ** theta
            sqrtW = 1 / np.sqrt(var_fun ** 2)
        else:
            sqrtW = 1
    else:
        sqrtW = wt

    # Gauss-Newton / IRWLS step: QR-decompose the weighted Jacobian and
    # solve the weighted working-response system
    gradX = dfun(X, parm)
    weighted_X = sqrtW.reshape(len_y, 1) * gradX
    z = gradX @ parm + (y - mean_fun)
    weighted_z = sqrtW * z
    Q, R = ola.qr(weighted_X, mode="economic")
    new_parm = ola.solve(R, Q.T @ weighted_z)

    return new_parm
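
# The step above solves the weighted least-squares problem
#   min_b || sqrtW * (z - gradX @ b) ||^2
# via the economic QR of the weighted Jacobian: with weighted_X = Q @ R,
# the normal equations reduce to R @ b = Q.T @ weighted_z. Since R is
# upper triangular, ola.solve_triangular could be used in place of the
# general ola.solve, but the result is the same.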


def nls_irwls(X, y, fun, dfun, init, theta=1, tol=1e-8, maxiter=500):

    old_parm = init
    n_iter = 0

    while n_iter < maxiter:
        new_parm = update_parm(X, y, fun, dfun, parm=old_parm, theta=theta, wt=True)
        # Stopping criterion: largest relative change in the parameters
        parm_diff = np.max(np.abs(new_parm - old_parm) / np.abs(old_parm))
        print(parm_diff)

        if parm_diff < tol:
            break
        else:
            old_parm = new_parm
            n_iter += 1
            print(new_parm)

    if n_iter == maxiter:
        print("The algorithm failed to converge")
    else:
        return {"Estimated coefficient": new_parm}


x = np.array([0.25, 0.5, 0.75, 1, 1.25, 2, 3, 4, 5, 6, 8])
y = np.array([2.05, 1.04, 0.81, 0.39, 0.30, 0.23, 0.13, 0.11, 0.08, 0.10, 0.06])

def model(x, W):
    # Sum of two exponential decays; all four parameters enter on the log scale
    comp1 = jnp.exp(W[0])
    comp2 = jnp.exp(-jnp.exp(W[1]) * x)
    comp3 = jnp.exp(W[2])
    comp4 = jnp.exp(-jnp.exp(W[3]) * x)
    return comp1 * comp2 + comp3 * comp4


init = np.array([0.69, 0.69, -1.6, -1.6])

# autodiff Jacobian (forward mode), JIT-compiled
model_grad = jax.jit(jax.jacfwd(model, argnums=1))

# manually coded Jacobian
def dModel(x, W):
    # Analytic derivative of model(x, W) with respect to W, one column per parameter
    e1 = np.exp(W[1])
    e2 = np.exp(W[3])
    e5 = np.exp(-(x * e1))
    e6 = np.exp(-(x * e2))
    e7 = np.exp(W[0])
    e8 = np.exp(W[2])
    b1 = e5 * e7
    b2 = -(x * e5 * e7 * e1) 
    b3 = e6 * e8 
    b4 = -(x * e6 * e8 * e2)

    return np.array([b1, b2, b3, b4]).T

nls_irwls(x, y, model, model_grad, init=init, theta=1, tol=1e-8, maxiter=50)
nls_irwls(x, y, model, dModel, init=init, theta=1, tol=1e-8, maxiter=50)
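
One way to sanity-check that the autodiff Jacobian and the manual derivative agree is to compare them at the initial values (a minimal sketch reusing the definitions above; with JAX's default float32, the two only agree to single precision):

J_auto = np.asarray(model_grad(x, init))   # JAX Jacobian, shape (11, 4)
J_manual = dModel(x, init)                 # manual Jacobian, shape (11, 4)
print(np.allclose(J_auto, J_manual, atol=1e-5))   # expected: True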

Solution

  • One thing to be aware of is that, by default, JAX performs computations in 32-bit, while tools like R and NumPy perform computations in 64-bit. Since 1e-8 is below the resolution of 32-bit floating point (the float32 machine epsilon is about 1.2e-7), I suspect this is why your program is failing to converge; the quick check at the end of this answer illustrates the difference.

    You can enable 64-bit computation by putting this at the beginning of your script:

    from jax import config
    config.update('jax_enable_x64', True)
    

    After doing this, your program converges as expected. For more information, see JAX Sharp Bits: Double Precision.
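
    To see this concretely, you can compare machine epsilon in the two precisions and confirm that, once the flag is set, JAX defaults to 64-bit (a minimal, self-contained check):

    from jax import config
    config.update('jax_enable_x64', True)   # set before any other JAX work

    import jax.numpy as jnp
    import numpy as np

    # Machine epsilon: ~1.2e-7 in float32 vs ~2.2e-16 in float64, so a
    # relative-change tolerance of 1e-8 cannot be resolved in float32
    print(np.finfo(np.float32).eps)    # ~1.1920929e-07
    print(np.finfo(np.float64).eps)    # ~2.220446e-16

    # With x64 enabled, JAX arrays default to float64
    print(jnp.asarray([1.0]).dtype)    # float64 (float32 without the flag)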