I am trying to learn tidymodels and DALEXtra. I have successfully tuned a set of models with `workflow_map()`:
# Run the tuning function over every workflow in the set, using the same
# resamples, grid size, control object and seed for each one.
grid_results <- workflow_map(
  all_workflows,
  resamples = the_folds,
  grid = 100,
  control = grid_ctrl,
  seed = 1503,
  verbose = TRUE
)
# Rank the tuned candidates, keep only the ROC AUC rows, and show the first few.
grid_results |>
  rank_results() |>
  filter(.metric == "roc_auc") |>
  select(model, .config, roc_auc = mean, rank) |>
  head()
And one of my BART models looks like the "winner":
# A tibble: 6 × 4
model .config roc_auc rank
<chr> <chr> <dbl> <int>
1 bart Preprocessor1_Model046 0.656 1
I would like to feed that model to DALEXtra:
library(DALEXtra)
# Build a DALEX explainer for the winning BART model.
explainer_bart <-
  explain_tidymodels(
    x, # <--------------- what goes here?
    data = the_train,
    # NOTE(review): `y` is evaluated as an ordinary R object, so unless a vector
    # named `adherence_group` exists this presumably needs to be
    # `the_train$adherence_group` -- confirm.
    y = adherence_group,
    # The label must be a character string; a bare `bart` would resolve to the
    # parsnip model function rather than a name.
    label = "bart",
    verbose = FALSE
  )
I think the `explain_tidymodels()` function wants a fitted model. How can I extract it from the workflow set results?
I am a beginner, so clues for the clueless (ideally with links) would be greatly appreciated.
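If you tuned your BART model, the workflow set results hold only the resampling results, not a fitted model. You need to finalize the workflow with the best tuning parameters and fit it to get a fitted workflow object. That's what you give the DALEX function.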
Here's an example using `chi_features_res`, a built-in workflow set that uses the `Chicago` data:
library(tidymodels)
library(DALEXtra)
#> Loading required package: DALEX
#> Welcome to DALEX (version: 2.4.2).
#> Find examples and detailed introduction at: http://ema.drwhy.ai/
#>
#> Attaching package: 'DALEX'
#> The following object is masked from 'package:dplyr':
#>
#> explain
# Extract the (still unfitted) workflow of interest from the workflow set by id
workflow_object <- extract_workflow(chi_features_res, id = "plus_pca_lm")
# When the workflow has tuning parameters, take that workflow's tuning results
# and select the best candidate according to RMSE
best_results <-
  chi_features_res |>
  extract_workflow_set_result(id = "plus_pca_lm") |>
  select_best(metric = "rmse")
# Plug the selected parameters into the workflow, then fit it on the data
fitted_workflow_object <- fit(
  finalize_workflow(workflow_object, best_results),
  data = Chicago
)
fitted_workflow_object
#> ══ Workflow [trained] ══════════════════════════════════════════════════════════
#> Preprocessor: Recipe
#> Model: linear_reg()
#>
#> ── Preprocessor ────────────────────────────────────────────────────────────────
#> 5 Recipe Steps
#>
#> • step_date()
#> • step_holiday()
#> • step_dummy()
#> • step_zv()
#> • step_pca()
#>
#> ── Model ───────────────────────────────────────────────────────────────────────
#>
#> Call:
#> stats::lm(formula = ..y ~ ., data = data)
#>
#> Coefficients:
#> (Intercept) temp_min temp temp_max
#> -6.250e+02 1.624e-02 2.639e-02 4.314e-03
#> temp_change dew humidity pressure
#> NA -2.514e-02 1.124e-02 -1.116e-04
#> pressure_change wind wind_max gust
#> 8.107e-02 -1.912e-02 1.091e-03 -1.107e-02
#> gust_max percip percip_max weather_rain
#> 4.430e-03 -1.127e+01 -1.470e-01 -7.882e-01
#> weather_snow weather_cloud weather_storm Blackhawks_Away
#> -7.061e-01 -3.149e-01 5.831e-02 -1.395e-01
#> Blackhawks_Home Bulls_Away Bulls_Home Bears_Away
#> -3.423e-02 3.554e-02 3.418e-01 3.287e-01
#> Bears_Home WhiteSox_Away WhiteSox_Home Cubs_Away
#> 2.740e-01 -4.920e-01 NA NA
#> Cubs_Home date_year date_LaborDay date_NewYearsDay
#> NA 3.121e-01 8.171e-01 -1.004e+01
#> date_ChristmasDay date_dow_Mon date_dow_Tue date_dow_Wed
#> -1.127e+01 1.244e+01 1.384e+01 1.385e+01
#> date_dow_Thu date_dow_Fri date_dow_Sat date_month_Feb
#> 1.361e+01 1.302e+01 1.435e+00 3.602e-01
#> date_month_Mar date_month_Apr date_month_May date_month_Jun
#> 6.938e-01 9.297e-01 5.221e-01 1.397e+00
#> date_month_Jul date_month_Aug date_month_Sep date_month_Oct
#> 7.532e-01 9.335e-01 9.002e-01 1.545e+00
#> date_month_Nov date_month_Dec PC1 PC2
#> 2.633e-01 -3.567e-01 6.636e-04 1.461e-01
#> PC3 PC4 PC5 PC6
#> 4.950e-01 -1.577e-01 -4.550e-02 4.059e-01
#> PC7 PC8 PC9
#> -1.665e-01 -6.379e-02 2.689e-01
# Hand the fitted workflow to the explainer: predictors go in `data`, the
# outcome column goes in `y`.
# This call emits some warnings, but you can disregard them.
explainer_obj <- explain_tidymodels(
  fitted_workflow_object,
  data = select(Chicago, -ridership),
  y = Chicago[["ridership"]],
  label = "model",
  verbose = FALSE
)
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading
#> Warning in predict.lm(object = object$fit, newdata = new_data, type =
#> "response"): prediction from a rank-deficient fit may be misleading
Created on 2022-10-11 by the reprex package (v2.0.1)