Tags: r, optimization, ode, model-fitting

Fitting a model to observed data by optimising parameters with optim(), minimising the residual sum of squares


I am trying to fit this very simple four-species linear Lotka-Volterra competition model to observed data, but when I pass the problem to optim() something on the deSolve side seems to fail.

# Data
data <- data.frame(Cod = c(0.1966126, 0.1989563, 0.2567677, 0.3158896, 0.4225435, 0.7219856,
                           1.0570824, 0.7266830, 0.6286763, 0.6389475),
                   Herring = c(1.988372, 2.788014, 3.397138, 2.557245, 2.627013, 3.045617, 
                               3.161002, 3.531306, 3.432021, 3.617174),
                   Sprat = c(2.030273, 3.480469, 3.009277, 1.895996, 2.457520, 1.991211, 2.350098,
                             2.118164, 1.693359, 1.869141),
                   Flounder = c(0.4758220, 0.4425532, 0.4185687, 0.4967118, 0.7102515, 0.5733075,
                                0.7404255, 0.5996132, 0.6235977, 0.7187621))
# Model formulation
LLV <- function(time, state, parameters) {
  with(as.list(c(state, parameters)), {
    db1.dt = b1*(r1+a11*b1+a12*b2+a13*b3+a14*b4)
    db2.dt = b2*(r2+a22*b2+a21*b1+a23*b3+a24*b4)
    db3.dt = b3*(r3+a33*b3+a31*b1+a32*b2+a34*b4)
    db4.dt = b4*(r4+a44*b4+a41*b1+a42*b2+a43*b3)
    list(c(db1.dt, db2.dt, db3.dt, db4.dt))
  })
}
# Model input and simulation
# Model input
params <- c(r1 = -0.342085, r2 = 0.6855681, r3 = 2.757769, r4 = 0.9744113,
            a11 = -1.05973762, a12 = 0.09577309, a13 = -0.01915480, a14 = 1.36098939,
            a21 = 0.17533326, a22 = -0.32247342, a23 = 0.03111628, a24 = 0.30212711,
            a31 = 0.5303516, a32 = -0.4869761, a33 = -0.3194882, a34 = -1.5089027,
            a41 = 0.004418133, a42 = 0.163716414, a43 = -0.237873378, a44 = -1.519158802)
ini <- c(b1 = data[1,1], b2 = data[1,2], b3 = data[1,3], b4 = data[1,4])
tmax <- 10
t <- seq(1,tmax,0.1)
# Results: the first parameter guess is more or less okay
results <- deSolve::ode(y = ini, times = t, func = LLV, parms = params)
matplot(data, pch = 1)
matplot(x = results[,1], y = results[,-1], type = "l", add = TRUE)

Here I write a function that computes the residual sum of squares; when this is passed to optim() together with the initial parameter guess above, it should produce the fit I am looking for.

min.RSS <- function(data, params) {
  output <- deSolve::ode(y = ini, times = t, func = LLV, parms = params)
  predictions <- exp(output[,-1])
  observations <- data
  return(sum((predictions-observations)^2))
}
result <- optim(par = params, fn = min.RSS, data = data)
fit <- deSolve::ode(y = ini, times = t, func = LLV, parms = result$par)
matplot(x = fit[,1], y = fit[,-1], type = "l", lwd = 3, add = TRUE)

How to solve this problem?


Solution

  • You can get a better fit, but you should be very careful with this problem. I went a little crazy and used the (in-development) fitode package to tackle it. I fitted the model and got a much better fit, and I also tried fitting with 100 randomly varying starting points around my best fit. Your residual sum of squares was 1.19; fitode got to 0.29 on the first try, and the best of the 100 fits reached RSS = 0.16. However, these fits are highly unstable. The plot produced by the code below shows the fits to the data and the predictions 5 time steps into the future for (1) your fit (dashed lines); (2) the initial fitode fit (dotted lines); and (3) the 100 other fitode fits (the ones within 0.05 RSS of the best fit are solid, the ones worse than that are drawn very lightly).

    You can see that the out-of-sample predictions are mostly crazy. Your fit is actually more stable than some of the better fits - it gets to time step 13 before the entire community crashes - but the bottom line is that a good fit to the data in this case in no way guarantees a sensible answer. It looks like a single one of the 100 fits reaches the end of the prediction time series without collapsing (which seems like a reasonably sensible "common sense" prediction based on the observed time series).

    In order to fit these data reliably, you either need a model with many fewer parameters, external information supplied in the form of priors, or some form of regularization: a way to penalize fits that imply 'wiggly' deterministic trajectories, or interaction parameters/growth rates that are unreasonable.
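
    As a rough illustration of the regularization idea, here is a minimal sketch of a ridge-style penalty added to the question's least-squares objective. The helper name penalized.RSS and the weight lambda are only illustrative (lambda would need tuning), and the ODE is solved only at the ten observation times so that predictions and observations line up.

    penalized.RSS <- function(params, data, lambda = 0.1) {
        out <- deSolve::ode(y = ini, times = seq(nrow(data)), func = LLV, parms = params)
        rss <- sum((out[, -1] - as.matrix(data))^2)
        ## ridge-style penalty: large growth rates/interaction strengths are
        ## discouraged, which also damps 'wiggly' trajectories
        rss + lambda * sum(params^2)
    }
    ## e.g. optim(par = params, fn = penalized.RSS, data = data,
    ##            control = list(maxit = 1e5))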

    ## remotes::install_github("parksw3/fitode")
    library(fitode)
    
    ## data with tags for fitode
    data2 <- setNames(data,paste0(names(data),"_obs"))
    data2 <- data.frame(times=seq(nrow(data2)),data2)
    
    
    ## Model formulation (for fitode)
    pars <- names(params)  ## names of the parameters to be estimated (r1..r4, a11..a44)
    LV_model <- odemodel(
        name="4-species LV",
        model=list(
            Cod ~ Cod*(r1+a11*Cod+a12*Herring+a13*Sprat+a14*Flounder),
            Herring ~ Herring*(r2+a22*Herring+a21*Cod+a23*Sprat+a24*Flounder),
            Sprat ~ Sprat*(r3+a33*Sprat+a31*Cod+a32*Herring+a34*Flounder),
            Flounder ~ Flounder*(r4+a44*Flounder+a41*Cod+a42*Herring+a43*Sprat)
        ),
        observation=list(
            Cod_obs ~ ols(mean=Cod),
            Herring_obs ~ ols(mean=Herring),
            Sprat_obs ~ ols(mean=Sprat),
            Flounder_obs ~ ols(mean=Flounder)
        ),
        initial=list(
            Cod ~ data2$Cod_obs[1],
            Herring ~ data2$Herring_obs[1],
            Sprat ~ data2$Sprat_obs[1],
            Flounder ~ data2$Flounder_obs[1]
        ),
        link=setNames(rep("identity",length(pars)),pars),
        par= pars
    )
    
    ## plot results
    plotres <- function(p,ODEint="rk",lty=1,
                        dt=0.1,
                        tvec=seq(1,10,by=dt),...) {
        par(las=1, bty="l")
        res <- deSolve::ode(ini, tvec, LLV, p, method=ODEint)
        matplot(res[,1],res[,-1],type="l",lty=lty,...)
        return(invisible(res[,-1]))
    }
    
    f1 <- fitode(
        LV_model,
        data=data2,
        start=params,
        control=list(maxit=1e5,trace=1000)
    )
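
    ## rough residual-sum-of-squares check of the fitode fit, reusing the
    ## deSolve model from the question and evaluating only at the ten
    ## observation times (this assumes coef(f1) returns a parameter vector
    ## with the same names as `params`, as used for plotting below)
    rss <- function(p) {
        out <- deSolve::ode(y = ini, times = seq(nrow(data)), func = LLV, parms = p)
        sum((out[, -1] - as.matrix(data))^2)
    }
    rss(params)    ## at the initial parameter guess
    rss(coef(f1))  ## at the fitode fit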
    
    ## fitode with multistart
    
    ranfit <- function(n,fit,range=0.5) {
        ## randomly perturb the original starting values (by up to +/-50% with the default range) and refit
        rpars <- params*runif(length(params),1-range,1+range)
        newfit <- try(update(fit, start=rpars))
        return(newfit)
    }
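
    ## serial alternative if no cluster is available (slower but simpler; a
    ## sketch, without the reproducible RNG stream set up below):
    ## multifit <- lapply(1:100, ranfit, fit = f1)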
    
    library(parallel)
    cl <- makeCluster(10)
    clusterSetRNGStream(cl = cl, 101)
    clusterExport(cl, c("params","LV_model","data2"))
    clusterEvalQ(cl,invisible(library(fitode)))
    system.time(
        multifit <- parLapply(cl, 1:100, ranfit, fit=f1)
    )
    saveRDS(multifit,file="SO65440448_multifit.rds")
    stopCluster(cl)
    
    ivec <- seq_along(multifit)
    ivec <- ivec[sapply(multifit,function(x) !inherits(x,"try-error"))]
    coefs <- pred <- vector("list", length=length(ivec))
    ll <- conv <- rep(NA,length(ivec))
    for (i in seq_along(ivec)) {
        nf <- multifit[[ivec[i]]]
        coefs[[i]] <- coef(nf)
        pp <- predict(nf, times=1:10)
        pred[[i]] <- cbind(times=pp[[1]][,1],
                    do.call(cbind,lapply(pp,"[",-1)))
        ll[i] <- logLik(nf)
        conv[i] <- nf@mle2@details$convergence
    }
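
    ## pick out the best-scoring multistart fit and, reusing the rss() helper
    ## defined above, check how its residual sum of squares compares
    best <- which.max(ll)
    rss(coefs[[best]])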
    
    par(las=1,bty="l")
    matplot(pred[[1]][,1],pred[[1]][,-1],
            type="n",lty=1,ylim=c(0,6),
            xlab="time",ylab="density")
    lthresh <- 0.05
    for (i in 1:length(pred)) {
        good <- ll[i]>(max(ll)-lthresh)
        alpha <- if (good) 0.8 else 0.1
        lwd <- if (good) 2 else 1
        matlines(pred[[i]][,1],pred[[i]][,-1],lty=1,
                 col=adjustcolor(palette()[1:4],alpha.f=alpha),
                 lwd=lwd)
    }
    matpoints(data2[,1],data2[,-1],pch=16,cex=3)
    ## overlay the question's optim() fit (dashed) and the initial fitode fit (dotted)
    plotres(result$par,add=TRUE, lwd=3,lty=2,dt=1)
    plotres(coef(f1),add=TRUE, lwd=3,lty=3,dt=1)
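
    ## out-of-sample sketch of the instability described above: extend the two
    ## single fits 5 time steps past the data (to t = 15) on a fresh plot
    plotres(result$par, tvec = seq(1, 15, by = 0.1), lty = 2, lwd = 2)
    plotres(coef(f1), tvec = seq(1, 15, by = 0.1), lty = 3, lwd = 2, add = TRUE)
    matpoints(data2[,1], data2[,-1], pch = 16)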