r tidymodels dalex

How do I extract a model fit from a tidymodels workflowset?


I am trying to learn tidymodels and DALEXtra. I have successfully built a set of models with workflow_map():

grid_results <-
   all_workflows %>%
   workflow_map(
      seed = 1503,
      resamples = the_folds,
      grid = 100,
      control = grid_ctrl,
      verbose=TRUE
   )

grid_results %>% 
  rank_results() %>% 
  filter(.metric == "roc_auc") %>% 
  select(model, .config, roc_auc = mean, rank) |> 
  head()

And one of my BART models looks like the "winner":

# A tibble: 6 × 4
  model        .config                roc_auc  rank
  <chr>        <chr>                    <dbl> <int>
1 bart         Preprocessor1_Model046   0.656     1

I would like to feed that model to DALEXtra:

library(DALEXtra)

explainer_bart <- 
  explain_tidymodels(
    x, # <--------------- what goes here?
    data = the_train,
    y = adherence_group,
    label = "bart",
    verbose = FALSE
  )


I think the explain_tidymodels() function expects a fitted model. How can I extract one from the workflow set results?

I am a beginner. So clues for the clueless (ideally with links) would be greatly appreciated.


Solution

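  • If you tuned your BART model, you need a fitted workflow object: extract the workflow, finalize it with the best tuning parameter values, and fit it to your training data. That fitted workflow is what you pass to explain_tidymodels().

    Here's an example using chi_features_res, a workflow set result that ships with the workflowsets package and is built on the Chicago data: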

    library(tidymodels)
    library(DALEXtra)
    #> Loading required package: DALEX
    #> Welcome to DALEX (version: 2.4.2).
    #> Find examples and detailed introduction at: http://ema.drwhy.ai/
    #> 
    #> Attaching package: 'DALEX'
    #> The following object is masked from 'package:dplyr':
    #> 
    #>     explain
    
    # Pull out the workflow that you want
    workflow_object <- 
      chi_features_res %>% 
      extract_workflow(id = "plus_pca_lm") 
    
    # If there are tuning parameters, get the best results
    best_results <- 
      chi_features_res %>% 
      extract_workflow_set_result(id = "plus_pca_lm") %>% 
      select_best(metric = "rmse")
    
    # Update your workflow and fit: 
    fitted_workflow_object <- 
      workflow_object %>% 
      finalize_workflow(best_results) %>% 
      fit(data = Chicago)
    
    fitted_workflow_object
    #> ══ Workflow [trained] ══════════════════════════════════════════════════════════
    #> Preprocessor: Recipe
    #> Model: linear_reg()
    #> 
    #> ── Preprocessor ────────────────────────────────────────────────────────────────
    #> 5 Recipe Steps
    #> 
    #> • step_date()
    #> • step_holiday()
    #> • step_dummy()
    #> • step_zv()
    #> • step_pca()
    #> 
    #> ── Model ───────────────────────────────────────────────────────────────────────
    #> 
    #> Call:
    #> stats::lm(formula = ..y ~ ., data = data)
    #> 
    #> Coefficients:
    #>       (Intercept)           temp_min               temp           temp_max  
    #>        -6.250e+02          1.624e-02          2.639e-02          4.314e-03  
    #>       temp_change                dew           humidity           pressure  
    #>                NA         -2.514e-02          1.124e-02         -1.116e-04  
    #>   pressure_change               wind           wind_max               gust  
    #>         8.107e-02         -1.912e-02          1.091e-03         -1.107e-02  
    #>          gust_max             percip         percip_max       weather_rain  
    #>         4.430e-03         -1.127e+01         -1.470e-01         -7.882e-01  
    #>      weather_snow      weather_cloud      weather_storm    Blackhawks_Away  
    #>        -7.061e-01         -3.149e-01          5.831e-02         -1.395e-01  
    #>   Blackhawks_Home         Bulls_Away         Bulls_Home         Bears_Away  
    #>        -3.423e-02          3.554e-02          3.418e-01          3.287e-01  
    #>        Bears_Home      WhiteSox_Away      WhiteSox_Home          Cubs_Away  
    #>         2.740e-01         -4.920e-01                 NA                 NA  
    #>         Cubs_Home          date_year      date_LaborDay   date_NewYearsDay  
    #>                NA          3.121e-01          8.171e-01         -1.004e+01  
    #> date_ChristmasDay       date_dow_Mon       date_dow_Tue       date_dow_Wed  
    #>        -1.127e+01          1.244e+01          1.384e+01          1.385e+01  
    #>      date_dow_Thu       date_dow_Fri       date_dow_Sat     date_month_Feb  
    #>         1.361e+01          1.302e+01          1.435e+00          3.602e-01  
    #>    date_month_Mar     date_month_Apr     date_month_May     date_month_Jun  
    #>         6.938e-01          9.297e-01          5.221e-01          1.397e+00  
    #>    date_month_Jul     date_month_Aug     date_month_Sep     date_month_Oct  
    #>         7.532e-01          9.335e-01          9.002e-01          1.545e+00  
    #>    date_month_Nov     date_month_Dec                PC1                PC2  
    #>         2.633e-01         -3.567e-01          6.636e-04          1.461e-01  
    #>               PC3                PC4                PC5                PC6  
    #>         4.950e-01         -1.577e-01         -4.550e-02          4.059e-01  
    #>               PC7                PC8                PC9  
    #>        -1.665e-01         -6.379e-02          2.689e-01
    
    # Put that in the explainer
    # There are some warnings here (from the rank-deficient lm fit) but you can disregard them
    explainer_obj <- 
      explain_tidymodels(
        fitted_workflow_object, 
        data = Chicago %>% select(-ridership),
        y = Chicago$ridership,
        label = "model",
        verbose = FALSE
      )
    #> Warning in predict.lm(object = object$fit, newdata = new_data, type =
    #> "response"): prediction from a rank-deficient fit may be misleading
    
    #> Warning in predict.lm(object = object$fit, newdata = new_data, type =
    #> "response"): prediction from a rank-deficient fit may be misleading
    
    #> Warning in predict.lm(object = object$fit, newdata = new_data, type =
    #> "response"): prediction from a rank-deficient fit may be misleading
    

    Created on 2022-10-11 by the reprex package (v2.0.1)
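
The steps in the answer (extract the workflow, select the best parameters, finalize, fit) can be wrapped into one function and applied to the objects from the question. This is only a minimal sketch: `fit_best_workflow()` is a hypothetical helper, not part of tidymodels or DALEXtra, and it assumes the workflow set was tuned with the metric you rank by. It looks up the winning `wflow_id` with `rank_results()` because the `model` column shown in the question is the model type and may not match the workflow id.

```r
library(tidymodels)  # attaches workflowsets, tune, workflows, dplyr
library(DALEXtra)

# Hypothetical helper (an assumption, not a package function):
# refit the top-ranked workflow of a tuned workflow set on the training data.
fit_best_workflow <- function(wf_set_results, train_data, metric = "roc_auc") {
  # Id of the workflow that holds the best-ranked configuration
  best_id <-
    wf_set_results %>%
    rank_results(rank_metric = metric, select_best = TRUE) %>%
    filter(.metric == metric) %>%
    slice_min(rank, n = 1, with_ties = FALSE) %>%
    pull(wflow_id)

  # Tuning parameter values of that workflow's best candidate
  best_params <-
    wf_set_results %>%
    extract_workflow_set_result(id = best_id) %>%
    select_best(metric = metric)

  # Plug the values into the workflow and fit it
  wf_set_results %>%
    extract_workflow(id = best_id) %>%
    finalize_workflow(best_params) %>%
    fit(data = train_data)
}

# Usage with the objects from the question (assumed to exist in your session):
# bart_fit <- fit_best_workflow(grid_results, the_train, metric = "roc_auc")
#
# explainer_bart <-
#   explain_tidymodels(
#     bart_fit,
#     data = the_train %>% select(-adherence_group),
#     # Assumes adherence_group is a two-level factor; DALEX expects a
#     # numeric y, so check that the 0/1 coding matches the class you
#     # consider the event.
#     y = as.integer(the_train$adherence_group) - 1,
#     label = "bart",
#     verbose = FALSE
#   )
```

Refitting on the full training set is needed because the resampling results only store performance metrics, not a model trained on all of the training data.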