Tags: r, machine-learning, tidymodels, dalex

Remove variable in model_parts() plot


I want to remove certain variables from the plot.

# Packages 
library(tidymodels)
library(mlbench)

# Data: Pima diabetes plus a synthetic two-level grouping column
data("PimaIndiansDiabetes")
dat <- PimaIndiansDiabetes 
dat$some_new_group[1:384] <- "group 1" 
dat$some_new_group[385:768] <- "group 2"

# Train/test split (seeded for reproducibility)
set.seed(123)
ind <- initial_split(dat)
dat_train <- training(ind)
dat_test <- testing(ind)

# Recipe: demote `some_new_group` to a non-predictor role and drop
# `pressure` before the Yeo-Johnson transform of the numeric predictors
svm_rec <- 
  recipe(diabetes ~., data = dat_train) %>% 
  update_role(some_new_group, new_role = "group_var") %>% 
  step_rm(pressure) %>% 
  step_YeoJohnson(all_numeric_predictors())
    
# Model spec: RBF-kernel SVM fit with the kernlab engine
svm_spec <- 
  svm_rbf() %>% 
  set_mode("classification") %>% 
  set_engine("kernlab")

# Workflow bundles the recipe with the model spec
svm_wf <- 
  workflow() %>% 
  add_recipe(svm_rec) %>% 
  add_model(svm_spec)

# Train
svm_trained <- 
  svm_wf %>% 
  fit(dat_train)

# Explainer
library(DALEXtra)

# NOTE(review): `data` here is the FULL data set minus the outcome, so
# `pressure` and `some_new_group` are still handed to the explainer — that
# is why both show up in the importance plot below. Also note the explainer
# data is `dat`, not `dat_train`, while the model was fit on `dat_train`;
# and `as.numeric()` on the `diabetes` factor yields 1/2 rather than 0/1 —
# confirm both are intended.
svm_exp <- explain_tidymodels(svm_trained, 
                              data = dat %>% select(-diabetes), 
                              y = dat$diabetes %>% as.numeric(), 
                              label = "SVM")
# Permutation-based variable importance (seeded: permutations are random)
set.seed(123)
svm_vp <- model_parts(svm_exp, type = "variable_importance") 
svm_vp

plot(svm_vp) +
  ggtitle("Mean-variable importance over 50 permutations", "") 

[variable importance plot — includes pressure and some_new_group]

Notice that in the recipe above, I removed the variable pressure and made a new categorical variable (some_new_group).

So, I can remove the variables pressure and some_new_group from the plot manually like this:

# Exclude the unwanted rows before plotting.
# NOTE: `variable != c("pressure", "some_new_group")` would recycle the
# 2-element vector element-wise down the column, silently keeping/dropping
# the wrong rows; `%in%` tests membership against the whole set.
plot(svm_vp %>% filter(!variable %in% c("pressure", "some_new_group"))) +
  ggtitle("Mean-variable importance over 50 permutations", "") 

[variable importance plot — pressure and some_new_group removed]

But, is it possible to remove the variables when I run explain_tidymodels() or model_parts()?


Solution

  • If you have variables that are not predictors or outcomes handled by your workflow() (like the variable you remove and your grouping variable), you want to make sure you only pass outcomes and predictors to explain_tidymodels(). You'll also need to build the explainer with the parsnip model, rather than the workflow() which is expecting to handle those non-outcome, non-predictor variables:

    library(tidymodels)
    
    # Data: Pima diabetes plus a synthetic two-level grouping column
    data("PimaIndiansDiabetes", package = "mlbench")
    dat <- PimaIndiansDiabetes 
    dat$some_new_group[1:384] <- "group 1" 
    dat$some_new_group[385:768] <- "group 2"
    
    # Train/test split (seeded for reproducibility)
    set.seed(123)
    ind <- initial_split(dat)
    dat_train <- training(ind)
    dat_test <- testing(ind)
    
    # Recipe: demote `some_new_group` to a non-predictor role and drop
    # `pressure`, so neither is a predictor the model ever sees
    svm_rec <- 
      recipe(diabetes ~., data = dat_train) %>% 
      update_role(some_new_group, new_role = "group_var") %>% 
      step_rm(pressure) %>% 
      step_YeoJohnson(all_numeric_predictors())
    
    # Model spec: RBF-kernel SVM fit with the kernlab engine
    svm_spec <- 
      svm_rbf() %>% 
      set_mode("classification") %>% 
      set_engine("kernlab")
    
    # Train (workflow(preprocessor, spec) shorthand for add_recipe/add_model)
    svm_trained <- 
      workflow(svm_rec, svm_spec) %>% 
      fit(dat_train)
    
    # Explainer
    library(DALEXtra)
    #> Loading required package: DALEX
    #> Welcome to DALEX (version: 2.4.0).
    #> Find examples and detailed introduction at: http://ema.drwhy.ai/
    #> 
    #> Attaching package: 'DALEX'
    #> The following object is masked from 'package:dplyr':
    #> 
    #>     explain
    
    # Key step: explain the underlying parsnip fit (not the workflow), and
    # pass only the baked predictors — prep()/bake() applies the recipe, and
    # all_predictors() excludes `some_new_group` (non-predictor role) while
    # step_rm has already dropped `pressure`, leaving 7 predictor columns.
    svm_exp <- explain_tidymodels(
      extract_fit_parsnip(svm_trained), 
      data = svm_rec %>% prep() %>% bake(new_data = NULL, all_predictors()), 
      # NOTE(review): as.numeric() on this factor yields 1/2, not 0/1; the
      # residual summary below (mean ~0.99) reflects that offset — confirm
      # whether `as.numeric(...) - 1` was intended.
      y = dat_train$diabetes %>% as.numeric(), 
      label = "SVM"
    )
    #> Preparation of a new explainer is initiated
    #>   -> model label       :  SVM 
    #>   -> data              :  576  rows  7  cols 
    #>   -> data              :  tibble converted into a data.frame 
    #>   -> target variable   :  576  values 
    #>   -> predict function  :  yhat.model_fit  will be used (  default  )
    #>   -> predicted values  :  No value for predict function target column. (  default  )
    #>   -> model_info        :  package parsnip , ver. 0.2.1 , task classification (  default  ) 
    #>   -> predicted values  :  numerical, min =  0.08057345 , mean =  0.3540662 , max =  0.9357536  
    #>   -> residual function :  difference between y and yhat (  default  )
    #>   -> residuals         :  numerical, min =  0.1083522 , mean =  0.9948921 , max =  1.895405  
    #>   A new explainer has been created!
    
    # Permutation-based variable importance — `pressure` and
    # `some_new_group` no longer appear because they never reached the
    # explainer's `data`
    set.seed(123)
    svm_vp <- model_parts(svm_exp, type = "variable_importance") 
    svm_vp
    #>       variable mean_dropout_loss label
    #> 1 _full_model_         0.6861190   SVM
    #> 2      glucose         0.5919956   SVM
    #> 3         mass         0.6673947   SVM
    #> 4     pregnant         0.6700007   SVM
    #> 5          age         0.6701185   SVM
    #> 6     pedigree         0.6702812   SVM
    #> 7      triceps         0.6760106   SVM
    #> 8      insulin         0.6777355   SVM
    #> 9   _baseline_         0.5020752   SVM
    
    plot(svm_vp) +
      ggtitle("Mean-variable importance over 50 permutations", "") 
    
    

    Created on 2022-05-03 by the reprex package (v2.0.1)

    If you have these "extra" variables in your workflow that shouldn't be used for explainability, then you'll need to do some extra work and can't rely on the workflow() alone.