r ggplot2 plot regression mixed-models

How to plot interaction effects of a mixed model in ggplot2?


I would like to get predictions for the interaction effect of a model (e.g. using sjPlot::plot_model or the ggeffects package) and then feed them to ggplot2 to visualize the interaction. Could someone please help with the code to do that? My question differs from others on Stack Overflow because it involves getting predictions using standard packages first, and then plotting them.

My current model:

model <- glmmTMB(total_count ~ mean_temp*lwd_duration + (1|year), family = nbinom1, data=df)

summary(model)

Family: nbinom1  ( log )
Formula:          total_count ~ mean_temp * lwd_duration + (1 | year)
Data: df

     AIC      BIC   logLik deviance df.resid 
   260.6    270.7   -124.3    248.6       34 

Random effects:

Conditional model:
 Groups Name        Variance  Std.Dev. 
 year   (Intercept) 3.683e-09 6.069e-05
Number of obs: 40, groups:  year, 4

Dispersion parameter for nbinom1 family ():  178 

Conditional model:
                        Estimate Std. Error z value Pr(>|z|)   
(Intercept)             9.609769   3.276724   2.933  0.00336 **
mean_temp              -0.341435   0.177426  -1.924  0.05431 . 
lwd_duration           -0.105528   0.039145  -2.696  0.00702 **
mean_temp:lwd_duration  0.005819   0.002078   2.801  0.00510 **
---
Signif. codes:  0 ‘***’ 0.001 ‘**’ 0.01 ‘*’ 0.05 ‘.’ 0.1 ‘ ’ 1

Reproducible example:

df <- structure(list(year = structure(c(1L, 1L, 1L, 1L, 1L, 1L, 1L, 
                                        1L, 1L, 1L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 2L, 3L, 3L, 3L, 
                                        3L, 3L, 3L, 3L, 3L, 3L, 3L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 4L, 
                                        4L), levels = c("2017", "2016", "2015", "2014"), class = "factor"), 
                     mean_temp = c(10.31, 10.31, 11.35, 11.35, 14.05, 14.05, 15.96, 
                                   15.96, 16.73, 16.73, 20.92, 20.92, 21.89, 21.89, 21.48, 21.48, 
                                   25.82, 25.82, 21.06, 21.06, 16.3, 16.3, 20.16, 20.16, 19.85, 
                                   19.85, 25.45, 25.45, 24.32, 24.32, 20.59, 20.78, 20.78, 20.78, 
                                   22.23, 22.23, 22.23, 19.71, 19.71, 19.71), lwd_duration = c(116.53, 
                                                                                               116.53, 146.18, 146.18, 184.48, 184.48, 67.3, 67.3, 70.08, 
                                                                                               70.08, 71.43, 71.43, 68.58, 68.58, 91.6, 91.6, 72.07, 72.07, 
                                                                                               59.57, 59.57, 13.02, 13.02, 75.33, 75.33, 60.07, 60.07, 98.08, 
                                                                                               98.08, 71.78, 71.78, 40.25, 54.5, 54.5, 54.5, 32.47, 32.47, 
                                                                                               32.47, 61.73, 61.73, 61.73), total_count = c(0L, 0L, 0L, 
                                                                                                                                            0L, 0L, 0L, 0L, 11L, 0L, 0L, 4L, 2L, 6L, 6L, 11L, 1L, 12L, 
                                                                                                                                            0L, 1L, 2L, 2L, 2L, 18L, 362L, 135L, 684L, 34L, 123L, 21L, 
                                                                                                                                            6L, 0L, 3L, 0L, 0L, 0L, 0L, 4L, 4L, 3L, 13L)), class = c("grouped_df", 
                                                                                                                                                                                                     "tbl_df", "tbl", "data.frame"), row.names = c(NA, -40L), groups = structure(list(
                                                                                                                                                                                                       year = structure(1:4, levels = c("2017", "2016", "2015", 
                                                                                                                                                                                                                                        "2014"), class = "factor"), .rows = structure(list(1:10, 
                                                                                                                                                                                                                                                                                           11:20, 21:30, 31:40), ptype = integer(0), class = c("vctrs_list_of", 
                                                                                                                                                                                                                                                                                                                                               "vctrs_vctr", "list"))), row.names = c(NA, -4L), .drop = TRUE, class = c("tbl_df", 
                                                                                                                                                                                                                                                                                                                                                                                                                        "tbl", "data. Frame")))

Solution

  • The typical way to do this from scratch would be to create a data frame of all the combinations of your predictors, then run predict on it:

    library(glmmTMB)
    library(ggplot2)
    
    model <- glmmTMB(total_count ~ mean_temp*lwd_duration + (1|year),
                     family = nbinom1, data = df)
    model
    
    pred_df <- expand.grid(mean_temp = seq(10, 25, len = 100),
                           lwd_duration = seq(10, 200, 0.1),
                           year = factor(2017, 2017:2014))
    
    pred_df$total_count <- predict(model, newdata = pred_df, type = "response")
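
    As a rough illustration of what `expand.grid` does in the step above (a minimal sketch with made-up values; `small_grid` is a hypothetical name, not part of the original code), every value of one predictor is paired with every value of the other, so the number of rows is the product of the vector lengths. The same arithmetic gives the size of the grid used above, which may be worth reducing if plotting is slow:

    ```r
    # Toy grid: 4 temperatures x 3 durations (values chosen only for illustration)
    small_grid <- expand.grid(mean_temp = c(10, 15, 20, 25),
                              lwd_duration = c(50, 100, 150))
    nrow(small_grid)  # 4 * 3 = 12 rows, one per combination

    # Size of the grid built above: 100 temperatures x 1901 durations
    n_rows <- length(seq(10, 25, length.out = 100)) * length(seq(10, 200, 0.1))
    n_rows  # 190100
    ```

    A coarser step for `lwd_duration` (for example 10 instead of 0.1) would give far fewer lines to draw while showing the same pattern.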
    

    There are various ways to plot the result, but choosing one variable for the color scale and one for the x axis gives an aesthetically pleasing and fairly intuitive output. You could show temperature on the x axis:

    ggplot(pred_df, aes(mean_temp, total_count, color = lwd_duration,
                        group = lwd_duration)) +
      geom_line(alpha = 0.1) +
      scale_color_distiller(palette = "Spectral") +
      coord_cartesian(ylim = range(df$total_count), xlim = range(df$mean_temp)) +
      theme_minimal(base_size = 22)
    


    Or use temperature for color:

    ggplot(pred_df, aes(lwd_duration, total_count, color = mean_temp,
                        group = mean_temp)) +
      geom_line() +
      scale_color_distiller(palette = "RdBu") +
      coord_cartesian(ylim = range(df$total_count), xlim = range(df$lwd_duration)) +
      theme_minimal(base_size = 22)
    


    Although the plots look good, I suspect the model is somewhat over-fit, which should become apparent if you plot the actual data alongside the predictions.
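
    One way to check this, and to get a more conventional interaction plot, is to predict at only a few values of the moderator, add confidence bands, and overlay the observed counts. The following is a minimal sketch under stated assumptions rather than a definitive implementation: `interaction_grid()` and `plot_interaction()` are hypothetical helper names, the 10th/50th/90th percentiles of `lwd_duration` are an arbitrary choice, and the bands are approximate Wald limits (plus or minus 1.96 standard errors on the log-link scale). Setting `year = NA` in the new data asks glmmTMB for population-level predictions; because the estimated year variance is essentially zero in this model, these should be almost identical to predictions for any particular year.

    ```r
    library(ggplot2)

    # Hypothetical helper: predictions along mean_temp at a few lwd_duration values
    interaction_grid <- function(model, data, n = 50, probs = c(0.1, 0.5, 0.9)) {
      grid <- expand.grid(
        mean_temp    = seq(min(data$mean_temp), max(data$mean_temp), length.out = n),
        lwd_duration = unname(quantile(data$lwd_duration, probs)),
        year         = NA  # NA grouping value: population-level prediction in glmmTMB
      )
      # Predict on the link (log) scale so the limits can be built symmetrically
      p <- predict(model, newdata = grid, type = "link", se.fit = TRUE)
      grid$fit   <- exp(p$fit)                    # back-transform to counts
      grid$lower <- exp(p$fit - 1.96 * p$se.fit)  # approximate 95% Wald limits
      grid$upper <- exp(p$fit + 1.96 * p$se.fit)
      grid$lwd_level <- factor(round(grid$lwd_duration, 1))
      grid
    }

    # Hypothetical helper: lines and ribbons for predictions, points for raw data
    plot_interaction <- function(pred, data) {
      ggplot(pred, aes(mean_temp, fit, colour = lwd_level, fill = lwd_level)) +
        geom_ribbon(aes(ymin = lower, ymax = upper), alpha = 0.2, colour = NA) +
        geom_line() +
        geom_point(data = data, aes(mean_temp, total_count), inherit.aes = FALSE) +
        scale_y_continuous(trans = "log1p") +
        labs(y = "total_count", colour = "lwd_duration", fill = "lwd_duration") +
        theme_minimal()
    }
    ```

    With the objects from the question this would be called as `plot_interaction(interaction_grid(model, df), df)`. Points that fall far outside the ribbons suggest the smooth surfaces above are more confident than the data justify.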


    If you want to use ggpredict, you could do something like:

    library(ggeffects)
    library(dplyr)  # provides %>%, rename() and mutate()

    ggpredict(model, terms = c("mean_temp[10:27]", "lwd_duration[10:200]"),
              type = "random") %>%
      as.data.frame() %>%
      rename(mean_temp = x, lwd_duration = group, total_count = predicted) %>%
      mutate(lwd_duration = as.numeric(as.character(lwd_duration))) %>%
      ggplot(aes(mean_temp, total_count, color = lwd_duration, 
                 group = lwd_duration)) + 
      geom_line() +
      geom_point(data = df, shape = 21, aes(fill = lwd_duration), color = "black",
                 size = 2.5) +
      scale_color_distiller(palette = "Spectral") +
      scale_fill_distiller(palette = "Spectral", guide = "none") +
      coord_cartesian(ylim = range(df$total_count), xlim = range(df$mean_temp)) +
      scale_y_continuous(trans = "log1p") +
      theme_minimal(base_size = 22)
    


    Note that you can change the range and density of values for each term inside the square brackets after the term name (see Details in the ggpredict help file). You may only be able to use a few levels of each predictor, and you can plot at most three variables; any other variables are held constant.
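
    To illustrate the bracket syntax, here is a hedged sketch that asks ggpredict for every observed `mean_temp` value (`[all]`) at three representative `lwd_duration` values (`[meansd]`, i.e. the mean and one standard deviation either side), which gives the familiar three-line interaction plot. The wrapper name `plot_ggpredict_interaction()` is hypothetical, the choice of `[meansd]` is an assumption about what is useful here, and the default prediction type is used rather than `type = "random"`, so the bands reflect the fixed effects only:

    ```r
    library(ggeffects)
    library(ggplot2)

    # Hypothetical wrapper: three-line interaction plot from ggpredict output
    plot_ggpredict_interaction <- function(model) {
      pred <- ggpredict(model,
                        terms = c("mean_temp [all]", "lwd_duration [meansd]"))
      pred <- as.data.frame(pred)  # columns include x, predicted, conf.low, conf.high, group
      ggplot(pred, aes(x, predicted, colour = group, fill = group)) +
        geom_ribbon(aes(ymin = conf.low, ymax = conf.high), alpha = 0.2, colour = NA) +
        geom_line() +
        labs(x = "mean_temp", y = "Predicted total_count",
             colour = "lwd_duration", fill = "lwd_duration") +
        theme_minimal()
    }
    ```

    Calling `plot_ggpredict_interaction(model)` with the fitted model from the question should return a regular ggplot object, so further layers and scales can be added with `+` as in the examples above.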