I'm a bit confused about getting metrics from resamples using tidymodels.
I seem to be getting 3 different sets of metrics from the same set of resamples, depending on whether I use collect_predictions() %>% metrics() or simply collect_metrics().
Here is a simple example...
library(tidyverse)
library(tidymodels)
starwars_df <- starwars %>% select(name:sex) %>% drop_na()
lasso_linear_reg_glmnet_spec <-
  linear_reg(penalty = .1, mixture = 1) %>%
  set_engine('glmnet')
basic_rec <-
  recipe(mass ~ height + sex + skin_color,
         data = starwars_df) %>%
  step_novel(all_nominal_predictors()) %>%
  step_other(all_nominal_predictors()) %>%
  step_dummy(all_nominal_predictors()) %>%
  step_nzv(all_predictors())
sw_wf <- workflow() %>%
  add_recipe(basic_rec) %>%
  add_model(lasso_linear_reg_glmnet_spec)
sw_boots <- bootstraps(starwars_df, times = 50)
resampd <- fit_resamples(
  sw_wf,
  sw_boots,
  control = control_resamples(save_pred = TRUE)
)
The following three lines give different results:
resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred)
resampd %>% collect_predictions(summarize = FALSE) %>% metrics(mass, .pred)
resampd %>% collect_metrics()
As an additional question, what would be the best/correct way to get confidence intervals for the rmse in the above example? Here is one way...
individ_metrics <- resampd %>% collect_predictions() %>% group_by(id) %>% rmse(mass, .pred)
confintr::ci_mean(individ_metrics$.estimate)
mean(individ_metrics$.estimate)
Thanks!
None of those results match because they are not aggregated in the same way. It turns out that taking a mean of a set of means doesn't give you the same (right) result as taking the mean of the whole underlying set. If you do something like resampd %>% collect_predictions(summarize = TRUE) %>% metrics(mass, .pred), that is like taking a mean of a set of means: the predictions for each row are averaged across resamples first, and the metrics are then computed on those averages.
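To make the aggregation difference concrete, here is a minimal base R sketch on made-up numbers (the toy data frame and the rmse() helper are assumptions for illustration, not part of the example above). It computes an RMSE three ways, which roughly correspond to collect_metrics(), collect_predictions(summarize = FALSE) %>% metrics(), and collect_predictions(summarize = TRUE) %>% metrics():

```r
# Toy data, invented for illustration: two "resamples" of different sizes.
# `row` plays the role of the original observation index.
toy <- data.frame(
  id    = c("A", "A", "B", "B", "B", "B"),
  row   = c(1, 2, 1, 2, 3, 4),
  truth = c(10, 20, 10, 20, 30, 40),
  pred  = c(12, 17, 11, 26, 27, 44)
)

# Hypothetical helper: plain root mean squared error.
rmse <- function(truth, pred) sqrt(mean((truth - pred)^2))

# 1. One RMSE per resample, then the mean of those RMSEs.
per_resample  <- sapply(split(toy, toy$id), function(d) rmse(d$truth, d$pred))
mean_of_rmses <- mean(per_resample)

# 2. A single RMSE over all predictions pooled together.
pooled_rmse <- rmse(toy$truth, toy$pred)

# 3. Average the predictions for each original row, then a single RMSE.
avg <- aggregate(cbind(truth, pred) ~ row, data = toy, FUN = mean)
rmse_of_avg_pred <- rmse(avg$truth, avg$pred)

print(c(mean_of_rmses    = mean_of_rmses,
        pooled_rmse      = pooled_rmse,
        rmse_of_avg_pred = rmse_of_avg_pred))
```

The three numbers differ because RMSE is not linear: averaging before or after taking the square root, or averaging the predictions themselves, changes the result.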
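It turns out that these two things are the same: computing the metrics per resample from the unsummarized predictions, and calling collect_metrics(summarize = FALSE). The rmse and rsq values match; the first tibble has more rows only because metrics() also returns mae.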
## these are the same:
resampd %>%
  collect_predictions(summarize = FALSE) %>%
  group_by(id) %>%
  metrics(mass, .pred)
#> # A tibble: 150 × 4
#> id .metric .estimator .estimate
#> <chr> <chr> <chr> <dbl>
#> 1 Bootstrap01 rmse standard 16.4
#> 2 Bootstrap02 rmse standard 23.1
#> 3 Bootstrap03 rmse standard 31.6
#> 4 Bootstrap04 rmse standard 17.6
#> 5 Bootstrap05 rmse standard 9.59
#> 6 Bootstrap06 rmse standard 25.0
#> 7 Bootstrap07 rmse standard 16.3
#> 8 Bootstrap08 rmse standard 35.1
#> 9 Bootstrap09 rmse standard 25.7
#> 10 Bootstrap10 rmse standard 25.3
#> # … with 140 more rows
resampd %>% collect_metrics(summarize = FALSE)
#> # A tibble: 100 × 5
#> id .metric .estimator .estimate .config
#> <chr> <chr> <chr> <dbl> <chr>
#> 1 Bootstrap01 rmse standard 16.4 Preprocessor1_Model1
#> 2 Bootstrap01 rsq standard 0.799 Preprocessor1_Model1
#> 3 Bootstrap02 rmse standard 23.1 Preprocessor1_Model1
#> 4 Bootstrap02 rsq standard 0.193 Preprocessor1_Model1
#> 5 Bootstrap03 rmse standard 31.6 Preprocessor1_Model1
#> 6 Bootstrap03 rsq standard 0.608 Preprocessor1_Model1
#> 7 Bootstrap04 rmse standard 17.6 Preprocessor1_Model1
#> 8 Bootstrap04 rsq standard 0.836 Preprocessor1_Model1
#> 9 Bootstrap05 rmse standard 9.59 Preprocessor1_Model1
#> 10 Bootstrap05 rsq standard 0.860 Preprocessor1_Model1
#> # … with 90 more rows
Created on 2022-08-23 with reprex v2.0.2
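As a rough sketch of how the per-resample values roll up, the snippet below summarizes them by metric in base R. It assumes that the default collect_metrics() output reports a mean, a count n, and a std_err computed as sd / sqrt(n) for each metric. The data frame holds only the first five resamples, copied (rounded) from the printed output above, so the numbers will not match the real 50-resample summary.

```r
# First five per-resample estimates, copied (rounded) from the output above.
per_resample <- data.frame(
  id        = rep(sprintf("Bootstrap%02d", 1:5), each = 2),
  .metric   = rep(c("rmse", "rsq"), times = 5),
  .estimate = c(16.4, 0.799, 23.1, 0.193, 31.6, 0.608,
                17.6, 0.836, 9.59, 0.860)
)

# One summary row per metric: mean, number of resamples, standard error.
summary_tbl <- do.call(
  rbind,
  lapply(split(per_resample, per_resample$.metric), function(d) {
    data.frame(
      .metric = d$.metric[1],
      mean    = mean(d$.estimate),
      n       = nrow(d),
      std_err = sd(d$.estimate) / sqrt(nrow(d))
    )
  })
)

print(summary_tbl)
```

The mean column here is the same quantity as mean(individ_metrics$.estimate) in the question, just restricted to five resamples.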