
Tidymodels prediction methods giving different results


I'm a bit confused about getting metrics from resamples using tidymodels.

I seem to be getting three different metrics from the same set of resamples, depending on whether I use collect_predictions() %>% metrics() (with summarize = TRUE or FALSE) or simply collect_metrics().

Here is a simple example:

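library(tidyverse)
library(tidymodels)

# Keep the columns name through sex and drop rows with any missing value
starwars_df <- starwars %>% select(name:sex) %>% drop_na()

# Lasso regression (mixture = 1) with a fixed penalty, fitted with glmnet
lasso_linear_reg_glmnet_spec <-
  linear_reg(penalty = .1, mixture = 1) %>%
  set_engine('glmnet')

# Predict mass from height, sex and skin colour; handle unseen factor levels,
# pool rare levels, create dummy variables, drop near-zero-variance predictors
basic_rec <-
  recipe(mass ~ height + sex + skin_color,
         data = starwars_df) %>%
  step_novel(all_nominal_predictors()) %>%
  step_other(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_nzv(all_predictors())

sw_wf <- workflow() %>%
  add_recipe(basic_rec) %>%
  add_model(lasso_linear_reg_glmnet_spec)

# 50 bootstrap resamples; set a seed so the resamples are reproducible
set.seed(123)
sw_boots <- bootstraps(starwars_df, times = 50)

# Fit on each bootstrap sample, evaluate on its out-of-bag rows, keep the predictions
resampd <- fit_resamples(
  sw_wf,
  sw_boots,
  control = control_resamples(save_pred = TRUE)
)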

The following three lines give different results:

resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred)
resampd %>% collect_predictions(summarize = FALSE) %>% metrics(mass, .pred)
resampd %>% collect_metrics()

As an additional question: what would be the best (correct) way to get confidence intervals for the RMSE in the above example? Here is one way:

# RMSE within each bootstrap resample (collect_predictions() defaults to summarize = FALSE)
individ_metrics <- resampd %>% collect_predictions() %>% group_by(id) %>% rmse(mass, .pred)
confintr::ci_mean(individ_metrics$.estimate)
mean(individ_metrics$.estimate)
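One conventional option, sketched here in base R, is a t-interval for the mean of the per-resample RMSEs, built from the same mean and standard error that collect_metrics() reports (its std_err is the standard deviation of the per-resample values divided by the square root of their count). This assumes the per-resample RMSEs are roughly normal; bootstrap resamples also overlap, so the values are not independent and the interval should be read as approximate. It is not claimed to be exactly what confintr::ci_mean() computes internally. The input vector holds the first five rounded RMSEs printed in the reprex output further down, used purely for illustration.

```r
# Illustrative per-resample RMSEs: the first five (rounded) bootstrap values from
# the reprex output; in practice use individ_metrics$.estimate (assumption: 1 value per resample)
rmse_by_resample <- c(16.4, 23.1, 31.6, 17.6, 9.59)

n       <- length(rmse_by_resample)
est     <- mean(rmse_by_resample)
# Standard error of the mean: sd of the per-resample values / sqrt(number of resamples)
std_err <- sd(rmse_by_resample) / sqrt(n)

# 95% t-interval for the mean of the per-resample RMSEs
t_ci <- est + c(-1, 1) * qt(0.975, df = n - 1) * std_err
t_ci
```

With all 50 resamples the interval would be considerably narrower, since std_err shrinks with the square root of the number of resamples.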

Thanks!


Solution

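  • The reason none of those are the same is that they are not aggregated in the same way. Taking a mean of a set of means does not, in general, give the same result as computing the statistic over the whole underlying set (for RMSE, the square root taken inside each per-resample value adds a further difference). The three calls aggregate as follows:

    • resampd %>% collect_metrics() computes the metric on the out-of-bag rows of each bootstrap resample and reports the mean of those 50 values.
    • resampd %>% collect_predictions(summarize = FALSE) %>% metrics(mass, .pred) pools the out-of-bag predictions from all resamples and computes a single metric over the pooled set.
    • resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred) first averages the predictions for each original row across the resamples in which it was held out, then computes the metric on those averaged predictions, which is another way of averaging before computing the metric.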

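    What does match is computing the metric within each resample: these two give the same per-resample values (metrics() also returns mae, which fit_resamples() does not compute by default, so the first result has 150 rows rather than 100):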

    ## these are the same:
    resampd %>% 
        collect_predictions(summarize = FALSE) %>% 
        group_by(id) %>% 
        metrics(mass, .pred)
    #> # A tibble: 150 × 4
    #>    id          .metric .estimator .estimate
    #>    <chr>       <chr>   <chr>          <dbl>
    #>  1 Bootstrap01 rmse    standard       16.4 
    #>  2 Bootstrap02 rmse    standard       23.1 
    #>  3 Bootstrap03 rmse    standard       31.6 
    #>  4 Bootstrap04 rmse    standard       17.6 
    #>  5 Bootstrap05 rmse    standard        9.59
    #>  6 Bootstrap06 rmse    standard       25.0 
    #>  7 Bootstrap07 rmse    standard       16.3 
    #>  8 Bootstrap08 rmse    standard       35.1 
    #>  9 Bootstrap09 rmse    standard       25.7 
    #> 10 Bootstrap10 rmse    standard       25.3 
    #> # … with 140 more rows
    resampd %>% collect_metrics(summarize = FALSE)
    #> # A tibble: 100 × 5
    #>    id          .metric .estimator .estimate .config             
    #>    <chr>       <chr>   <chr>          <dbl> <chr>               
    #>  1 Bootstrap01 rmse    standard      16.4   Preprocessor1_Model1
    #>  2 Bootstrap01 rsq     standard       0.799 Preprocessor1_Model1
    #>  3 Bootstrap02 rmse    standard      23.1   Preprocessor1_Model1
    #>  4 Bootstrap02 rsq     standard       0.193 Preprocessor1_Model1
    #>  5 Bootstrap03 rmse    standard      31.6   Preprocessor1_Model1
    #>  6 Bootstrap03 rsq     standard       0.608 Preprocessor1_Model1
    #>  7 Bootstrap04 rmse    standard      17.6   Preprocessor1_Model1
    #>  8 Bootstrap04 rsq     standard       0.836 Preprocessor1_Model1
    #>  9 Bootstrap05 rmse    standard       9.59  Preprocessor1_Model1
    #> 10 Bootstrap05 rsq     standard       0.860 Preprocessor1_Model1
    #> # … with 90 more rows
    

    Created on 2022-08-23 with reprex v2.0.2
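The three aggregation schemes can be reproduced by hand in a minimal base-R sketch. The table below is hypothetical toy data, not the starwars results; its columns id, .row and .pred mirror those in collect_predictions() output, with truth standing in for mass. rmse_of() is a hypothetical hand-written helper (deliberately not named after yardstick's own functions).

```r
# Hypothetical out-of-bag predictions: one row per (resample, original row) pair,
# laid out like collect_predictions(summarize = FALSE) output
preds <- data.frame(
  id    = c("Boot1", "Boot1", "Boot1", "Boot2", "Boot2", "Boot3"),
  .row  = c(1, 2, 3, 1, 4, 2),
  truth = c(10, 20, 30, 10, 40, 20),
  .pred = c(12, 18, 35, 8, 41, 26)
)

# Hypothetical helper: root mean squared error of a set of predictions
rmse_of <- function(truth, estimate) sqrt(mean((truth - estimate)^2))

# Like collect_metrics(): RMSE within each resample, then the mean of those values
per_resample <- sapply(split(preds, preds$id),
                       function(d) rmse_of(d$truth, d$.pred))
mean_of_rmse <- mean(per_resample)

# Like collect_predictions(summarize = FALSE) %>% metrics(): one RMSE over all pooled rows
pooled_rmse <- rmse_of(preds$truth, preds$.pred)

# Like collect_predictions(summarize = TRUE) %>% metrics(): average each row's
# predictions across resamples first, then one RMSE over the averaged rows
avg <- aggregate(cbind(truth, .pred) ~ .row, data = preds, FUN = mean)
averaged_rmse <- rmse_of(avg$truth, avg$.pred)

# mean_of_rmse is about 3.633, pooled_rmse about 3.512, averaged_rmse about 2.739
round(c(mean_of_rmse = mean_of_rmse, pooled_rmse = pooled_rmse,
        averaged_rmse = averaged_rmse), 3)
```

The three numbers differ even on six rows: the pooled value weights every prediction equally (so Boot1, with three rows, counts more than Boot3), the per-resample mean weights every resample equally, and the averaged version lets errors on repeated rows partly cancel before squaring.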