Are there ways to access the component functions of a generalised additive model (GAM) fitted using the mgcv library?
Say I fit a gam as follows
library(mgcv)
md = gam(y ~ ti(x1) + ti(x2) + ti(x1,x2))
For further analysis, I would like to evaluate each of the components ti(x1), ti(x2) and ti(x1,x2) separately, on their own. I know it is possible in principle, since for example plot.gam can plot each component separately. But I could find no indication in the help files of how to access those components. Did I overlook something, or is my only way forward to parse the source of plot.gam?
EDIT
"Access" means that I am able to mimic predict for each component. That means, if my model looks like
$$ f(x)=\sum_{i=1}^N f_i(x)$$ then I would like to calculate $f_i(x)$ on its own for any suitable input $x$. This would include the case that single $f_i$ are multivariate tensors, i.e. they look like $f_i= ti(x_1, \ldots, x_d).$
In addition, it would be great if I could also evaluate the gradients $\nabla f_i$ analytically, i.e. without recourse to numeric approximation of the derivatives.
I'm biased, of course, but if you don't mind using tidyverse-oriented tools, then my gratia package makes doing everything you want pretty simple.
Here's an example
library("gratia")
library("mgcv")
df <- data_sim("eg1", seed = 42)
m <- gam(y ~ s(x0) + s(x1) + s(x2) + s(x3), data = df, method = "REML")
We can now evaluate the estimated smooth functions at specified values of the covariates. If you don't specify the values, evenly spaced values over the range of each covariate will be generated for you:
sm <- smooth_estimates(m)
This sm is a tibble, where est is the estimated function at the specified (or generated) values of the covariate(s), as indicated in the other columns:
> sm
# A tibble: 400 × 9
   smooth type  by       est    se       x0    x1    x2    x3
   <chr>  <chr> <chr>  <dbl> <dbl>    <dbl> <dbl> <dbl> <dbl>
 1 s(x0)  TPRS  NA    -1.32  0.390 0.000239    NA    NA    NA
 2 s(x0)  TPRS  NA    -1.24  0.365 0.0103      NA    NA    NA
 3 s(x0)  TPRS  NA    -1.17  0.340 0.0204      NA    NA    NA
 4 s(x0)  TPRS  NA    -1.09  0.318 0.0304      NA    NA    NA
 5 s(x0)  TPRS  NA    -1.02  0.297 0.0405      NA    NA    NA
 6 s(x0)  TPRS  NA    -0.947 0.279 0.0506      NA    NA    NA
 7 s(x0)  TPRS  NA    -0.875 0.263 0.0606      NA    NA    NA
 8 s(x0)  TPRS  NA    -0.803 0.249 0.0707      NA    NA    NA
 9 s(x0)  TPRS  NA    -0.732 0.237 0.0807      NA    NA    NA
10 s(x0)  TPRS  NA    -0.662 0.228 0.0908      NA    NA    NA
# ℹ 390 more rows
# ℹ Use `print(n = ...)` to see more rows
If you want to do this for the observed data, just pass in that data:
sm2 <- smooth_estimates(m, data = df)
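As a cross-check that does not rely on gratia, mgcv's own predict() method can return the contribution of each term when called with type = "terms". The following is a minimal sketch, not part of the original example: it assumes data simulated with mgcv's gamSim() instead of data_sim(), and newd, term_mat, eta_from_terms and eta_direct are illustrative names.

```r
library(mgcv)

set.seed(42)
# Simulated data from mgcv's gamSim() (an assumption; any data frame
# containing y and x0..x3 would do)
df <- gamSim(1, n = 400, verbose = FALSE)
m <- gam(y ~ s(x0) + s(x1) + s(x2) + s(x3), data = df, method = "REML")

# Hypothetical new covariate values at which to evaluate each smooth
newd <- data.frame(x0 = c(0.1, 0.5), x1 = c(0.2, 0.6),
                   x2 = c(0.3, 0.7), x3 = c(0.4, 0.8))

# Matrix with one column per smooth: the centred contribution of each term
term_mat <- predict(m, newdata = newd, type = "terms")
colnames(term_mat)

# Row sums plus the intercept should reproduce the linear predictor
eta_from_terms <- rowSums(term_mat) + attr(term_mat, "constant")
eta_direct <- predict(m, newdata = newd, type = "link")
all.equal(unname(eta_from_terms), unname(as.numeric(eta_direct)))
```

Each column of term_mat corresponds to one $f_i$ evaluated at the rows of newd, subject to the sum-to-zero centring that mgcv applies to smooths; the intercept is kept separately in the "constant" attribute.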
You can plot the estimated functions using the draw() method:
sm |> draw()
If you want to estimate derivatives of each function on the linear predictor scale (it makes no difference for the model here, as it is Gaussian with an identity link function, but it does for models with a non-identity link), you can use the derivatives() function:
fd <- derivatives(m, type = "central")
which, given the defaults for the other arguments, computes first-order derivatives via central finite differences, together with a typical frequentist confidence interval:
> fd
# A tibble: 400 × 10
   smooth var   by_var fs_var     data derivative    se  crit lower upper
   <chr>  <chr> <chr>  <chr>     <dbl>      <dbl> <dbl> <dbl> <dbl> <dbl>
 1 s(x0)  x0    NA     NA     0.000239       7.41  3.33  1.96 0.874  13.9
 2 s(x0)  x0    NA     NA     0.0103         7.40  3.33  1.96 0.884  13.9
 3 s(x0)  x0    NA     NA     0.0204         7.39  3.30  1.96 0.929  13.8
 4 s(x0)  x0    NA     NA     0.0304         7.36  3.24  1.96 1.01   13.7
 5 s(x0)  x0    NA     NA     0.0405         7.32  3.15  1.96 1.14   13.5
 6 s(x0)  x0    NA     NA     0.0506         7.26  3.04  1.96 1.30   13.2
 7 s(x0)  x0    NA     NA     0.0606         7.18  2.90  1.96 1.49   12.9
 8 s(x0)  x0    NA     NA     0.0707         7.09  2.76  1.96 1.69   12.5
 9 s(x0)  x0    NA     NA     0.0807         6.99  2.61  1.96 1.87   12.1
10 s(x0)  x0    NA     NA     0.0908         6.87  2.47  1.96 2.03   11.7
# ℹ 390 more rows
# ℹ Use `print(n = ...)` to see more rows
See ?gratia::derivatives for more.
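For reference, a central finite difference approximates the first derivative of a smooth $f_i$ in the following generic way. Here $h$ is a small step size in my own notation, not a gratia argument name, and the exact step convention gratia uses internally is not stated here.

$$ f_i'(x) \approx \frac{f_i(x + h) - f_i(x - h)}{2h} $$

Because this is a numerical approximation, it is not the analytic gradient $\nabla f_i$ asked about in the question, although for smooth functions and small $h$ the two should agree closely.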
If you want the weighted basis functions, use the basis() function:
bs <- basis(m)
which returns
> bs
# A tibble: 3,600 × 9
   smooth type  by_variable bf      value       x0    x1    x2    x3
   <chr>  <chr> <chr>       <fct>   <dbl>    <dbl> <dbl> <dbl> <dbl>
 1 s(x0)  TPRS  NA          1     -0.0818 0.000239    NA    NA    NA
 2 s(x0)  TPRS  NA          2      0.699  0.000239    NA    NA    NA
 3 s(x0)  TPRS  NA          3     -0.112  0.000239    NA    NA    NA
 4 s(x0)  TPRS  NA          4      0.262  0.000239    NA    NA    NA
 5 s(x0)  TPRS  NA          5     -0.448  0.000239    NA    NA    NA
 6 s(x0)  TPRS  NA          6      0.648  0.000239    NA    NA    NA
 7 s(x0)  TPRS  NA          7     -0.326  0.000239    NA    NA    NA
 8 s(x0)  TPRS  NA          8     -1.65   0.000239    NA    NA    NA
 9 s(x0)  TPRS  NA          9     -0.308  0.000239    NA    NA    NA
10 s(x0)  TPRS  NA          1     -0.0818 0.0103      NA    NA    NA
# ℹ 3,590 more rows
# ℹ Use `print(n = ...)` to see more rows
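To make explicit what "weighted" means here, the relationship can be written as below. This is the standard basis expansion of a penalised spline, not a description of gratia's internals, and the symbols $\hat\beta_{ik}$, $b_{ik}$ and $K_i$ are my notation for the estimated coefficients, the basis functions and the basis dimension of smooth $i$.

$$ \hat f_i(x) = \sum_{k=1}^{K_i} \hat\beta_{ik}\, b_{ik}(x) $$

Under that reading, each value in bs is one summand $\hat\beta_{ik}\, b_{ik}(x)$. As a check against the output above, the nine values listed for s(x0) at x0 = 0.000239 sum to roughly -1.32, which agrees with est in the first row of sm.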