
Tidymodels: How to extract the tuned parameters with workflow_map() and collect_metrics()?


Hello, I noticed that the collect_metrics() function does not return the parameter of interest, in this case cost_complexity, and I am not sure why. I created a small example with the iris dataset to show this.

library(tidymodels)
library(baguette)

set.seed(222)


# Put 3/4 of the data into the training set 
data_split <- initial_split(iris, prop = 3/4)

# Create data frames for the two sets:
train_data <- training(data_split)
test_data  <- testing(data_split)



# Create Recipes ----

Sepal_length_rec <- recipe(Sepal.Length ~ ., data = train_data)

Petal.Length_rec <- recipe(Petal.Length ~ ., data = train_data)


## Create Model -----

rpart_mod <- bag_tree() |> 
  set_engine("rpart") |> 
  set_mode("regression") |> 
  set_args(times = 10,
           cost_complexity = tune("cost_complexity"))

grid <- tibble(cost_complexity = seq(from = 0, to = 0.01, by = 0.01))

## Create workflow ----

model_set <- 
  workflow_set(
    preproc = list(Sepal = Sepal_length_rec, 
                   Petal = Petal.Length_rec),
    models = list(rpart_model = rpart_mod),
    cross = TRUE)

# Tuning -----
train_resamples <- vfold_cv(train_data, v = 2)

model_set <- model_set |> 
  workflow_map("tune_grid", resamples = train_resamples, grid = grid, verbose = TRUE)
#> i 1 of 2 tuning:     Sepal_rpart_model
#> ✔ 1 of 2 tuning:     Sepal_rpart_model (1.4s)
#> i 2 of 2 tuning:     Petal_rpart_model
#> ✔ 2 of 2 tuning:     Petal_rpart_model (1.1s)

Here the cost_complexity column is not included:

collect_metrics(model_set) # cost_complexity not included
#> # A tibble: 8 × 9
#>   wflow_id          .config preproc model .metric .estimator  mean     n std_err
#>   <chr>             <chr>   <chr>   <chr> <chr>   <chr>      <dbl> <int>   <dbl>
#> 1 Sepal_rpart_model Prepro… recipe  bag_… rmse    standard   0.358     2 4.23e-2
#> 2 Sepal_rpart_model Prepro… recipe  bag_… rsq     standard   0.798     2 1.91e-2
#> 3 Sepal_rpart_model Prepro… recipe  bag_… rmse    standard   0.356     2 2.57e-2
#> 4 Sepal_rpart_model Prepro… recipe  bag_… rsq     standard   0.803     2 5.35e-4
#> 5 Petal_rpart_model Prepro… recipe  bag_… rmse    standard   0.301     2 6.40e-2
#> 6 Petal_rpart_model Prepro… recipe  bag_… rsq     standard   0.968     2 1.42e-2
#> 7 Petal_rpart_model Prepro… recipe  bag_… rmse    standard   0.340     2 6.45e-2
#> 8 Petal_rpart_model Prepro… recipe  bag_… rsq     standard   0.961     2 1.84e-2

However, the cost_complexity values are present inside the stored results:

model_set$result[[1]]$.metrics
#> [[1]]
#> # A tibble: 4 × 5
#>   cost_complexity .metric .estimator .estimate .config             
#>             <dbl> <chr>   <chr>          <dbl> <chr>               
#> 1            0    rmse    standard       0.315 Preprocessor1_Model1
#> 2            0    rsq     standard       0.817 Preprocessor1_Model1
#> 3            0.01 rmse    standard       0.330 Preprocessor1_Model2
#> 4            0.01 rsq     standard       0.803 Preprocessor1_Model2
#> 
#> [[2]]
#> # A tibble: 4 × 5
#>   cost_complexity .metric .estimator .estimate .config             
#>             <dbl> <chr>   <chr>          <dbl> <chr>               
#> 1            0    rmse    standard       0.400 Preprocessor1_Model1
#> 2            0    rsq     standard       0.779 Preprocessor1_Model1
#> 3            0.01 rmse    standard       0.382 Preprocessor1_Model2
#> 4            0.01 rsq     standard       0.804 Preprocessor1_Model2

Does anyone know how I can extract the tuned parameter?


Solution

  • We don't return the parameter values from collect_metrics() on a workflow set because we can't guarantee that the parameter set will be the same across workflows.

    You can use extract_workflow_set_result() to pull out the tuning results for the workflow you want and work with them as usual:

    model_set %>% 
      # The next line gets you the tune_results object
      extract_workflow_set_result("Sepal_rpart_model") %>% 
      # Do whatever you would normally do.
      select_best(metric = "rmse")
    #> # A tibble: 1 × 2
    #>   cost_complexity .config             
    #>             <dbl> <chr>               
    #> 1            0.01 Preprocessor1_Model2
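To get the tuned parameters for every workflow at once, the same idea can be mapped over wflow_id. This is a minimal sketch, assuming the tidymodels and baguette packages are installed; the setup repeats the question's example in compact form, and the names all_metrics, best_params and sepal_fit are illustrative, not part of any API. Because the parameter columns can differ between workflows (the caveat in the answer), the results are combined with bind_rows(), which fills parameter columns a workflow doesn't have with NA instead of failing.

```r
library(tidymodels)
library(baguette)

# Compact version of the question's setup
set.seed(222)
data_split <- initial_split(iris, prop = 3/4)
train_data <- training(data_split)

rpart_mod <- bag_tree(cost_complexity = tune()) |>
  set_engine("rpart", times = 10) |>
  set_mode("regression")

model_set <- workflow_set(
  preproc = list(
    Sepal = recipe(Sepal.Length ~ ., data = train_data),
    Petal = recipe(Petal.Length ~ ., data = train_data)
  ),
  models = list(rpart_model = rpart_mod)
) |>
  workflow_map(
    "tune_grid",
    resamples = vfold_cv(train_data, v = 2),
    grid = tibble(cost_complexity = c(0, 0.01))
  )

# Summarised metrics *with* parameter columns, for every workflow
all_metrics <- purrr::map(model_set$wflow_id, function(id) {
  model_set |>
    extract_workflow_set_result(id) |>   # a regular tune_results object
    collect_metrics() |>                 # keeps cost_complexity
    mutate(wflow_id = id, .before = 1)   # record which workflow it came from
}) |>
  bind_rows()                            # missing parameter columns become NA

all_metrics

# Best parameter values per workflow, as a named list of one-row tibbles
best_params <- purrr::map(
  purrr::set_names(model_set$wflow_id),
  function(id) {
    model_set |>
      extract_workflow_set_result(id) |>
      select_best(metric = "rmse")
  }
)

best_params$Sepal_rpart_model

# Plug the best value back into one workflow and fit/evaluate on the split
sepal_fit <- model_set |>
  extract_workflow("Sepal_rpart_model") |>
  finalize_workflow(best_params$Sepal_rpart_model) |>
  last_fit(data_split)

collect_metrics(sepal_fit)
```

The last step is one example of "whatever you would normally do" with a tune_results object: finalize_workflow() fills in the tune() placeholder and last_fit() refits on the training portion and evaluates on the test portion of data_split.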