python, gradient, gradient-descent, autograd

SIR parameter estimation with gradient descent and autograd


I am trying to do a very simple parameter estimation for an SIR model using a gradient descent algorithm. I am using the autograd package since the audience (this is for a sort of workshop for undergraduate students) only knows NumPy, and I don't want to jump to JAX or any other ML framework (yet).

import autograd
import autograd.numpy as np
import matplotlib.pyplot as plt

from scipy.integrate import solve_ivp, odeint
from autograd.builtins import tuple
from autograd import grad, jacobian


def sir(y, t, beta, gamma):
  S, I, R = y
  dS_dt = - beta * S * I
  dI_dt = beta * S  * I - gamma * I
  dR_dt = gamma * I
  return np.array([dS_dt, dI_dt, dR_dt])

def loss(params, Y0, t, y_obs):
  beta, gamma = params
  # Solve the ODE system using odeint
  sol = odeint(sir, y0=Y0, t=t, args=(beta, gamma))

  # Compute the L2 norm error between the observed and predicted values
  err = np.linalg.norm(y_obs - sol, 2)
  return err

# Generate data
np.random.seed(42)
Y0 = np.array([0.95, 0.05, 0.0])
t = np.linspace(0, 30, 101)
beta, gamma = 0.5, 1/14
sol = odeint(sir, y0=Y0, t=t, args=tuple([beta, gamma]))
y_obs = sol + np.random.normal(0, 0.05, size=sol.shape)
plt.plot(t, y_obs)

Then, what I would like to do is something like this

# --- THIS DOES NOT WORK ---
params = np.array([beta_init, gamma_init])
# Get the gradient of the loss function with respect to the parameters (beta, gamma)
loss_grad = grad(loss, argnum=0)  # params is the first argument of loss
# Perform gradient descent
for i in range(n_iterations):
  grads = loss_grad(params, Y0, t, y_obs)  # Compute gradients
  params -= learning_rate * grads  # Update parameters

A minimal example would be:

loss_grad = grad(loss, argnum=0)
params = np.array([beta, gamma])
grads = loss_grad(params, Y0, t, y_obs)

However, I get the following error:

ValueError: setting an array element with a sequence.

which is raised right at the start, as soon as loss_grad is called.

Is there any way I can calculate the derivatives of the loss function with respect to my parameters (beta and gamma)? To be honest I am still getting used to auto-differentiation.


Solution

  • This is a modified version of your code that seems to work:

    import autograd
    import autograd.numpy as np
    import matplotlib.pyplot as plt
    
    from autograd.scipy.integrate import odeint
    from autograd.builtins import tuple
    from autograd import grad, jacobian
    
    
    def sir(y, t, beta, gamma):
      S, I, R = y
      dS_dt = - beta * S * I
      dI_dt = beta * S  * I - gamma * I
      dR_dt = gamma * I
      return np.array([dS_dt, dI_dt, dR_dt])
    
    def loss(params, Y0, t, y_obs):
      params_tuple = tuple(params)
      # Solve the ODE system using odeint
      sol = odeint(sir, Y0, t, params_tuple)
    
      # Compute the L2 norm error between the observed and predicted values
      err = np.linalg.norm(y_obs - sol)
      return err
    
    # Generate data
    np.random.seed(42)
    Y0 = np.array([0.95, 0.05, 0.0])
    t = np.linspace(0, 30, 101)
    beta, gamma = 0.5, 1/14
    sol = odeint(sir, y0=Y0, t=t, args=(beta, gamma))
    y_obs = sol + np.random.normal(0, 0.05, size=sol.shape)
    plt.plot(t, y_obs)
    

    Then, when running

    loss_grad = grad(loss, argnum=0)
    params = np.array([beta, gamma])
    grads = loss_grad(params, Y0, t, y_obs)
    print(grads)
    

    I get the output [-0.84506353 -7.09399783].
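    With the gradient working, you can drop it straight into the descent loop from your question. A minimal sketch (the starting values, learning rate, and iteration count below are purely illustrative and will need tuning):

    learning_rate = 1e-4
    params = np.array([0.4, 0.1])  # initial guesses for beta and gamma
    for i in range(500):
      grads = loss_grad(params, Y0, t, y_obs)  # gradient of the loss w.r.t. (beta, gamma)
      params = params - learning_rate * grads  # plain gradient-descent update
    print(params)
    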

    The important differences in this code are:

    from autograd.scipy.integrate import odeint
    

    This wrapped version defines the gradient of odeint and declares to autograd that it should be treated as a differentiation primitive rather than having its execution traced (a toy illustration of this mechanism is sketched after these notes).

    sol = odeint(sir, Y0, t, params_tuple)
    

    The parameters are passed positionally, wrapped in an autograd tuple, because autograd currently has an issue where keyword arguments to differentiation primitives are incorrectly passed as the boxed objects it uses for tracing execution, as reported here. That is a problem whenever the underlying function cannot handle the boxed object, as is the case here, and an error is raised.
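    For students who want to see what "being a differentiation primitive" means, here is a small, self-contained toy sketch of autograd's extension mechanism (the square function is purely illustrative and unrelated to how odeint's gradient is actually implemented):

    from autograd import grad
    from autograd.extend import primitive, defvjp

    # Mark square() as a primitive: autograd will not trace its internals,
    # so we must supply the gradient ourselves.
    @primitive
    def square(x):
      return x ** 2

    # Vector-Jacobian product: upstream gradient g times d(square)/dx = 2x
    defvjp(square, lambda ans, x: lambda g: g * 2 * x)

    print(grad(square)(3.0))  # prints 6.0
    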

    Hope this helps!