
broom.mixed::augment does not work with rsample::analysis


I have been trying to run 10-fold cross-validation with some of the tidymodels tools, and while everything else seems to work, I am having trouble augmenting the fitted models with augment(). The error seems to be linked to how the data are split with rsample::analysis(): the traceback ends with rsample::analysis() reporting that it has no method for objects of class function.

I can see that the predictions are there, e.g., cv_models[[3]][[1]]$fitted and that the folds also exist, e.g., split_folds$splits[[1]] %>% analysis(), but augment fails.

Any clues would be much appreciated!

library(tidymodels); library(nlme); library(broom.mixed)
#> Warning: package 'scales' was built under R version 4.3.1
#> Warning: package 'dplyr' was built under R version 4.3.1
#> Warning: package 'ggplot2' was built under R version 4.3.1
#> Warning: package 'modeldata' was built under R version 4.3.1
#> Warning: package 'recipes' was built under R version 4.3.1
#> Warning: package 'yardstick' was built under R version 4.3.1
#> 
#> Attaching package: 'nlme'
#> The following object is masked from 'package:dplyr':
#> 
#>     collapse

data <- 
  tibble(
    co2_mean_tot = rlnorm(n = 100, meanlog = 0, sdlog = 1),
    WALA = rlnorm(n = 100, meanlog = 0, sdlog = 1),
    precip_av = rnorm(n = 100, mean = 400, sd = co2_mean_tot),
    soil_c_kgC = rnorm(n = 100, mean = 40, sd = 5),
    area = rlnorm(n = 100, meanlog = 0, sdlog = 1),
    long = seq(100, 180, length.out = 100))


split_folds <- vfold_cv(data, strata = long, v = 10)

fit_gls_cv <- 
  function(split) {
    gls(log10(co2_mean_tot) ~ scale(log10(WALA)) * scale(precip_av) * scale(soil_c_kgC) +
          scale(log10(area)),
        weights = varExp(form = ~ precip_av), 
        data = rsample::analysis(split), method = "REML")
  }

cv_models <- 
  split_folds %>% 
  mutate(model = map(splits, fit_gls_cv),
         coef_info = map(model, tidy),
         gl_info = map(model, glance),
         pred = map(model, predict)) 

cv_models %>% 
  mutate(aug = map(model, augment))
#> Error in `mutate()`:
#> ℹ In argument: `aug = map(model, augment)`.
#> Caused by error in `map()`:
#> ℹ In index: 1.
#> Caused by error in `rsample::analysis()`:
#> ! No method for objects of class: function
#> Backtrace:
#>      ▆
#>   1. ├─cv_models %>% mutate(aug = map(model, augment))
#>   2. ├─dplyr::mutate(., aug = map(model, augment))
#>   3. ├─dplyr:::mutate.data.frame(., aug = map(model, augment))
#>   4. │ └─dplyr:::mutate_cols(.data, dplyr_quosures(...), by)
#>   5. │   ├─base::withCallingHandlers(...)
#>   6. │   └─dplyr:::mutate_col(dots[[i]], data, mask, new_columns)
#>   7. │     └─mask$eval_all_mutate(quo)
#>   8. │       └─dplyr (local) eval()
#>   9. └─purrr::map(model, augment)
#>  10.   └─purrr:::map_("list", .x, .f, ..., .progress = .progress)
#>  11.     ├─purrr:::with_indexed_errors(...)
#>  12.     │ └─base::withCallingHandlers(...)
#>  13.     ├─purrr:::call_with_cleanup(...)
#>  14.     ├─generics (local) .f(.x[[i]], ...)
#>  15.     └─broom.mixed:::augment.gls(.x[[i]], ...)
#>  16.       ├─broom::augment_columns(x, data, newdata, se.fit = NULL)
#>  17.       ├─nlme::getData(x)
#>  18.       └─nlme:::getData.gls(x)
#>  19.         ├─base::eval(if ("data" %in% names(object)) object$data else mCall$data)
#>  20.         │ └─base::eval(if ("data" %in% names(object)) object$data else mCall$data)
#>  21.         ├─rsample::analysis(split)
#>  22.         └─rsample:::analysis.default(split)
#>  23.           └─cli::cli_abort("No method for objects of class{?es}: {cls}")
#>  24.             └─rlang::abort(...)
Created on 2024-01-28 with reprex v2.0.2

Solution

  • Here's broom.mixed:::augment.gls:

    function (x, data = nlme::getData(x), newdata, ...)  {
        if (missing(newdata)) {
            newdata <- NULL
        }
        ret <- augment_columns(x, data, newdata, se.fit = NULL)
        ret
    }
    

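    The problem occurs when nlme::getData() tries to recover the data. It does this by evaluating the stored model$call$data, which is the unevaluated expression rsample::analysis(split). That evaluation happens inside nlme:::getData.gls(), where no fold object called split exists, so the name resolves to base::split(), which is a function. That is why rsample::analysis() complains that it has no method for objects of class function.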

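    A workaround is to supply the data to augment() yourself, so that getData() is never called. Note that .y$data is the full data set stored in the rsplit object, not the rows the fold's model was fitted to; rsample::analysis(.y) returns those rows. This seems to work, although I haven't checked it carefully:

    cv_models %>% 
       mutate(aug = map2(model, splits, ~ augment(.x, data = rsample::analysis(.y))))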
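The scoping problem behind the error can be reproduced in base R alone. In this minimal sketch, fake_fit(), take_rows() and fake_get_data() are hypothetical stand-ins for gls(), rsample::analysis() and nlme:::getData.gls(); the last one assumes getData re-evaluates the stored call with a bare eval(), as the backtrace in the question shows.

```r
# Hypothetical stand-in for a model function that, like nlme::gls(),
# stores its own call, with `data` kept as an unevaluated expression
fake_fit <- function(data) {
  list(call = match.call(), n = nrow(data))
}

# Hypothetical stand-in for rsample::analysis(): returns a subset of rows
take_rows <- function(x) x[1:5, , drop = FALSE]

# Mirrors fit_gls_cv(): the fold argument is called `split`
fit_fold <- function(split) fake_fit(data = take_rows(split))

fit <- fit_fold(mtcars)
print(fit$call$data)  # take_rows(split): an expression, not a data frame

# Mirrors the bare eval() in nlme:::getData.gls: `split` is looked up
# starting from this function's frame, not from fit_fold()'s frame
fake_get_data <- function(object) eval(object$call$data)

# No local `split` exists here, so lookup reaches base::split(), a function,
# and subsetting it fails
found <- tryCatch(fake_get_data(fit), error = function(e) conditionMessage(e))
print(found)
```

The fitting step works because `split` is evaluated inside fit_fold(), where it is the data; only the later re-evaluation, in a different frame, picks up the wrong object.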
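Instead of passing the data to augment() afterwards, the data can be embedded in each model's call at fitting time. The sketch below, which I have not tested against every nlme/broom.mixed version, builds the gls() call with do.call(), so the evaluated analysis set (rather than the expression rsample::analysis(split)) ends up in $call$data and nlme::getData() can return it. It re-creates the question's simulated data with an arbitrary seed; gls_formula is a name introduced here for convenience. The trade-off is that every stored call carries a copy of its fold's data, so printing a model can be very verbose.

```r
library(rsample)
library(purrr)
library(dplyr)
library(nlme)
library(broom.mixed)

set.seed(1)  # arbitrary seed so the simulated data are reproducible
data <- tibble(
  co2_mean_tot = rlnorm(100),
  WALA = rlnorm(100),
  precip_av = rnorm(100, mean = 400, sd = co2_mean_tot),
  soil_c_kgC = rnorm(100, mean = 40, sd = 5),
  area = rlnorm(100),
  long = seq(100, 180, length.out = 100)
)

split_folds <- vfold_cv(data, strata = long, v = 10)

gls_formula <- log10(co2_mean_tot) ~
  scale(log10(WALA)) * scale(precip_av) * scale(soil_c_kgC) + scale(log10(area))

fit_gls_cv <- function(split) {
  # do.call() evaluates the arguments before calling gls(), so the stored
  # call holds the analysis data frame itself, not analysis(split)
  do.call(
    gls,
    list(
      model = gls_formula,
      weights = varExp(form = ~ precip_av),
      data = analysis(split),
      method = "REML"
    )
  )
}

cv_models <- split_folds %>%
  mutate(
    model = map(splits, fit_gls_cv),
    # augment() falls back to getData(), which can now recover each fold's data
    aug = map(model, augment)
  )
```

Each element of cv_models$aug should then have one row per analysis row of its fold.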