
How to overlay basis functions onto GAM plot


I am trying to find a way to replicate a plot similar to this, where the splines and the basis functions that make up those splines are both plotted in the same window.

[Image: the target figure, a spline drawn together with the basis functions that make it up]

I have managed to draw each of them separately:

#### Load Libraries ####
library(mgcv)
library(tidyverse)
library(gratia)
library(gamair)
library(ggpubr)

#### Set Theme ####
theme_set(theme_bw())

#### Add Data ####
data("wesdr")
wes <- as_tibble(wesdr)
wes

#### Fit GAM ####
fit <- gam(
  ret ~ s(dur, bs = "cr"),
  method = "REML",
  family = binomial,
  data = wes
)

#### Plot Basis Functions ####
b <- draw(basis(fit))
s <- draw(fit)
ggarrange(b,s)

[Image: the basis functions (left) and the estimated smooth of dur (right), arranged side by side]

However, I'm not sure how to mash them together. Simply throwing them on top of each other obviously doesn't work:

#### Attempt at Plotting BF and Spline ####
wes %>% 
  ggplot(aes(x=dur,
             y=ret))+
  stat_smooth(method = "gam",
              method.args = list(family = binomial),
              formula = y ~ s(x, bs = "cr"),
              se = TRUE,
              color = "steelblue")+
  geom_line(data = basis(fit),
            aes(x=dur,
                y=value,
                color=bf))

[Image: the stat_smooth() fit and the basis functions from basis(fit) drawn on the same axes]

How can one achieve this?


Solution

  • The figure you show doesn't really use any response data. It needs only values of the spline covariate, and even those matter only because you need enough of them to draw nice, smooth basis functions. Drawing the basis for an estimated spline is a different matter. Assuming you might want both (the first for teaching or explaining how splines work, the second to explain a specific fit), below I show how to generate both kinds of figure.

    Option 1: using a basis and user-specified weights

    library("mgcv")
    library("gratia")
    library("dplyr")
    library("ggplot2")  # ggplot() is not attached by {gratia}
    
    df <- data.frame(x = seq(0, 1, length = 100))
    bs <- basis(s(x, bs = "bs", k = 10), data = df)
    
    # let's weight the basis functions (simulating model coefs)
    set.seed(1)
    betas <- data.frame(bf = factor(1:10), beta = rnorm(10))
    
    # we need to merge the weights for each basis function with the basis object
    bs <- bs |>
        left_join(betas, by = "bf") |>
        mutate(value_w = value * beta)
    
    # now we want to sum the weighted basis functions for each value of `x`
    spl <- bs |>
        group_by(x) |>
        summarise(spline = sum(value_w))
    
    # now plot
    bs |> 
        ggplot(aes(x = x, y = value_w, colour = bf, group = bf)) +
        geom_line(show.legend = FALSE) +
        geom_line(aes(x = x, y = spline), data = spl, linewidth = 1.5,
                  inherit.aes = FALSE) +
        labs(y = expression(f(x)), x = "x")
    

    This produces:

    [Image: ten weighted B-spline basis functions (thin coloured lines) and the spline formed by summing them (thick black line)]
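    Under the hood this is just a matrix of basis functions multiplied by a vector of weights. As a rough sketch, the same idea can be written with {mgcv} alone. It assumes smoothCon() with absorb.cons = FALSE, so no identifiability constraint is applied, as in Option 1. The names X, beta and X_w are only illustrative, and it plots with base graphics to avoid extra dependencies.

```r
library("mgcv")

# covariate values only: no response is needed to evaluate a basis
df <- data.frame(x = seq(0, 1, length = 100))

# build an unconstrained B-spline basis with 10 basis functions
sm <- smoothCon(s(x, bs = "bs", k = 10), data = df, absorb.cons = FALSE)[[1]]
X <- sm$X  # 100 x 10 matrix, one column per basis function

# random weights standing in for model coefficients
set.seed(1)
beta <- rnorm(ncol(X))

# weight each column (basis function) by its coefficient ...
X_w <- sweep(X, 2, beta, `*`)
# ... and sum across basis functions to get the spline at each x
f <- rowSums(X_w)

# thin lines: weighted basis functions; thick line: the spline
matplot(df$x, X_w, type = "l", lty = 1, xlab = "x", ylab = "f(x)")
lines(df$x, f, lwd = 3)
```

    Only a grid of x values goes in. No response variable is involved at any point.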

    Option 2: using an estimated model

    If you want to do this for an actual model fit, you could follow the above example, but you would need to include the identifiability constraints in the spline (see ?basis) and extract the correct weights for the basis functions from the vector of model coefficients returned by coef(m).
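    Here is a minimal sketch of that manual route using only {mgcv}. It assumes two things. First, the linear predictor matrix from predict(..., type = "lpmatrix") supplies the basis with the sum-to-zero constraint already absorbed. Second, the first.para/last.para fields of the fitted smooth index its coefficients in coef(m). Names such as newd, idx and x2_bf are illustrative.

```r
library("mgcv")
library("gratia") # used here only for data_sim()

dat <- data_sim("eg1", seed = 4)
m <- gam(y ~ s(x0) + s(x1) + s(x2, bs = "bs") + s(x3),
         data = dat, method = "REML")

# grid over x2; the other covariates are required by predict() but do not
# enter the s(x2) columns of the linear predictor matrix
newd <- data.frame(x0 = mean(dat$x0), x1 = mean(dat$x1),
                   x2 = seq(min(dat$x2), max(dat$x2), length.out = 200),
                   x3 = mean(dat$x3))

# the lpmatrix holds the basis functions after the identifiability
# (sum-to-zero) constraint has been absorbed
Xp <- predict(m, newdata = newd, type = "lpmatrix")

# find which coefficients belong to s(x2)
sm_labels <- vapply(m$smooth, function(sm) sm$label, character(1))
sm_x2 <- m$smooth[[which(sm_labels == "s(x2)")]]
idx <- sm_x2$first.para:sm_x2$last.para

# weight each constrained basis function by its estimated coefficient
x2_bf <- sweep(Xp[, idx], 2, coef(m)[idx], `*`)

# the estimated spline is the sum of the weighted basis functions
x2_spline <- rowSums(x2_bf)

matplot(newd$x2, x2_bf, type = "l", lty = 1, xlab = "x2", ylab = "f(x2)")
lines(newd$x2, x2_spline, lwd = 3)
```

    Summing the columns should reproduce predict(m, newd, type = "terms")[, "s(x2)"], which makes a handy sanity check.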

    {gratia}'s basis() has a method for fitted models, which automates this process.

    dat <- data_sim("eg1", seed = 4)
    m <- gam(y ~ s(x0) + s(x1) + s(x2, bs = "bs") + s(x3),
             data = dat, method = "REML")
    
    # data to evaluate the basis at
    # using the CRAN version of {gratia}, we need `m`
    ds <- data_slice(m, x2 = evenly(x2, n = 200))
    # from 0.9.0 (or current GitHub version) you can do
    # ds <- data_slice(dat, x2 = evenly(x2, n = 200))
    
    # generate a tidy representation of the fitted basis functions
    x2_bs <- basis(m, term = "s(x2)", data = ds)
    
    # compute values of the spline by summing basis functions at each x2
    x2_spl <- x2_bs |>
        group_by(x2) |>
        summarise(spline = sum(value))
    
    # now plot
    x2_bs |> 
        ggplot(aes(x = x2, y = value, colour = bf, group = bf)) +
        geom_line(show.legend = FALSE) +
        geom_line(aes(x = x2, y = spline), data = x2_spl, linewidth = 1.5,
                  inherit.aes = FALSE) +
        labs(y = expression(f(x2)), x = "x2")
    

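    This produces:

    [Image: the weighted basis functions of s(x2) (thin coloured lines) and the fitted spline obtained by summing them (thick black line)]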

    To get the final version you wanted (with a credible interval), evaluate the spline at the same covariate values using smooth_estimates() instead of manually summing the basis functions:

    # evaluate the spline at the same values as we evaluated the basis functions
    x2_sm <- smooth_estimates(m, "s(x2)", data = ds) |>
        add_confint()
    
    # now plot
    x2_bs |> 
        ggplot(aes(x = x2, y = value, colour = bf, group = bf)) +
        geom_line(show.legend = FALSE) +
        geom_ribbon(aes(x = x2, ymin = lower_ci, ymax = upper_ci),
                    data = x2_sm, # <---- new !
                    inherit.aes = FALSE, alpha = 0.2) +
        geom_line(aes(x = x2, y = est), data = x2_sm, # <---- new !
                  linewidth = 1.5, inherit.aes = FALSE) +
        labs(y = expression(f(x2)), x = "x2")
    

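    which produces:

    [Image: the weighted basis functions of s(x2) with the estimated spline from smooth_estimates() (thick black line) and its credible interval (grey band)]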
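    The credible interval that add_confint() adds is also a function of the basis. Here is a minimal sketch of computing it by hand with {mgcv}. It assumes the Bayesian covariance matrix returned by vcov(m) and a Gaussian critical value for 95% coverage. This is my understanding of gratia's default, not something taken from its source. Names such as X, V and crit are illustrative.

```r
library("mgcv")
library("gratia") # used here only for data_sim()

dat <- data_sim("eg1", seed = 4)
m <- gam(y ~ s(x0) + s(x1) + s(x2, bs = "bs") + s(x3),
         data = dat, method = "REML")
newd <- data.frame(x0 = mean(dat$x0), x1 = mean(dat$x1),
                   x2 = seq(min(dat$x2), max(dat$x2), length.out = 200),
                   x3 = mean(dat$x3))

# constrained basis and coefficients for s(x2)
Xp <- predict(m, newdata = newd, type = "lpmatrix")
sm_labels <- vapply(m$smooth, function(sm) sm$label, character(1))
sm_x2 <- m$smooth[[which(sm_labels == "s(x2)")]]
idx <- sm_x2$first.para:sm_x2$last.para

X <- Xp[, idx]          # basis functions evaluated on the grid
beta <- coef(m)[idx]    # their estimated weights
V <- vcov(m)[idx, idx]  # Bayesian posterior covariance of those weights

est <- drop(X %*% beta)               # the spline
se <- sqrt(rowSums((X %*% V) * X))    # diag(X V X^T), row by row
crit <- qnorm(0.975)                  # 95% interval, Gaussian approximation
lower <- est - crit * se
upper <- est + crit * se
```

    The row-wise product rowSums((X %*% V) * X) gives the diagonal of X V X^T without building the full 200 x 200 matrix.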

    Where were you going wrong?

    I think your approach didn't work for a couple of reasons.

    1. draw() methods don't return the underlying data. By design (because of how ggplot() works) they return ggplot objects. It is better to use functions (like you did with basis()) to get the outputs you want, and then plot them yourself using ggplot(), as I showed with the last example in Option 2.

    2. Don't ever use geom_smooth() or stat_smooth() to fit a GAM. It's easy to make mistakes. Here you forgot to ask for method = "REML", which you would need to do through method.args = list(family = binomial, method = "REML") in the stat_smooth() call. Without it, the smooth that stat_smooth() draws is not the same model as fit.
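    If you do keep stat_smooth() for a quick look, a minimal corrected version of the call from the question might look like the sketch below. It assumes {ggplot2} and {gamair} are installed. It still gives you no access to the fitted model or its basis functions, which is why fitting with gam() yourself is preferable.

```r
library("ggplot2")
library("gamair") # provides the wesdr data
data("wesdr")

# REML has to be passed through method.args, alongside the family
p <- ggplot(wesdr, aes(x = dur, y = ret)) +
  geom_point(alpha = 0.2) +
  stat_smooth(method = "gam",
              formula = y ~ s(x, bs = "cr"),
              method.args = list(family = binomial, method = "REML"),
              se = TRUE, colour = "steelblue")
p
```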

    Your approach isn't too far off, though. Notice that many of the basis functions on the left of the figure are negative, so they pull the fitted spline down, even though some of the other basis functions peak above the fitted spline.

    One final comment: use the {patchwork} package to arrange objects returned by draw(), as you'll get better alignment.

    library("patchwork")
    b + s + plot_layout(ncol = 2)
    

    draw.gam() and many other draw() methods in {gratia} already return patchworks, not simple ggplot objects, so you'll get the best compatibility if you use {patchwork}'s layout tools.