r tidymodels vip

Variable importance plot for a support vector machine with the tidymodels framework is not working


An SVM does not store any variable importance information itself. So I am using permutation-based importance to generate a variable importance plot, following this answer:

library(vip)
library(MASS)
library(tidymodels)

data(Boston, package = "MASS")
df <- Boston

#Split the data into train and test set
set.seed(7)
splits <- initial_split(df)
train <- training(splits)
test <- testing(splits)


#Preprocess with recipe
rec <- recipe(medv~.,data=train) %>%
  step_normalize(all_predictors()) 

svm_spec <- svm_rbf(margin = 0.0937, cost = 26.7, rbf_sigma = 0.0208) %>%
  set_engine("kernlab") %>%
  set_mode("regression")


#Putting into workflow
svr_fit <- workflow() %>%
  add_recipe(rec) %>%
  add_model(svm_spec) %>% 
  fit(data = train)

svr_fit %>%
  pull_workflow_fit() %>%
  vip(method = "permute", nsim = 5,
      target = "medv", metric = "rmse",
      pred_wrapper = kernlab::predict, train = train)

But it returns the following error:

Error in metric_fun(): ! estimate should be a numeric vector, not a numeric matrix.

In addition, pull_workflow_fit is deprecated. What should be used instead?


Solution

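  • TL;DR: pred_wrapper has to return a numeric vector, but kernlab::predict returns a one-column matrix, so wrap it in as.vector. The predictors should not be normalized in the recipe: the parsnip fit extracted from the workflow bypasses the recipe, so vip would pass the raw train data to a model that was fitted on normalized predictors, and the results would not be meaningful. pull_workflow_fit is replaced by extract_fit_parsnip.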

    library(vip)
    #> 
    #> Attaching package: 'vip'
    #> The following object is masked from 'package:utils':
    #> 
    #>     vi
    library(MASS)
    library(tidymodels)
    
    data(Boston, package = "MASS")
    df <- Boston
    
    # Split the data into train and test set
    set.seed(7)
    splits <- initial_split(df)
    train <- training(splits)
    test <- testing(splits)
    
    
    # Recipe without step_normalize(): vip() passes the raw `train` data to the extracted fit
    rec <- recipe(
      formula = medv ~ .,
      data = train
    ) 
    
    svm_spec <- svm_rbf(margin = 0.0937, cost = 20, rbf_sigma = 0.0208) %>%
      set_engine("kernlab") %>%
      set_mode("regression")
    
    
    # Putting into workflow
    svr_fit <- workflow() %>%
      add_recipe(rec) %>%
      add_model(svm_spec) %>%
      fit(data = train)
    
    svr_fit %>%
      extract_fit_parsnip() %>%
      vip(
        method = "permute", nsim = 5,
        target = "medv", metric = "rmse",
        geom = "point",
        # kernlab::predict() returns a one-column matrix; the metric needs a plain numeric vector
        pred_wrapper = function(object, newdata) as.vector(kernlab::predict(object, newdata)),
        train = train
      )
    

    Created on 2023-12-30 with reprex v2.0.2
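
To see why the wrapper matters, here is a minimal base-R sketch of permutation importance that does not need vip, kernlab, or tidymodels. It is only an illustration under stated assumptions, not how vip is implemented: the helpers rmse_strict, pred_matrix, pred_vector, and permute_importance are hypothetical names made up for this example, and a linear model on the built-in mtcars data stands in for the SVM. The metric deliberately rejects matrix input, mimicking the error from the question.

```r
# Hypothetical helpers for illustration only; none of these come from
# vip, yardstick, or kernlab.

# Metric that, like the error in the question, rejects matrix estimates.
rmse_strict <- function(truth, estimate) {
  if (is.matrix(estimate)) {
    stop("`estimate` should be a numeric vector, not a numeric matrix.")
  }
  sqrt(mean((truth - estimate)^2))
}

# Stand-in model: a linear regression on the built-in mtcars data.
lm_fit <- lm(mpg ~ wt + hp + qsec, data = mtcars)

# A wrapper that returns a one-column matrix, mimicking kernlab::predict().
pred_matrix <- function(object, newdata) {
  matrix(predict(object, newdata = newdata), ncol = 1)
}

# The fix: flatten the matrix into a plain numeric vector.
pred_vector <- function(object, newdata) {
  as.vector(pred_matrix(object, newdata))
}

# Permutation importance: shuffle one predictor at a time and record the
# mean increase in RMSE over the unshuffled baseline.
permute_importance <- function(object, train, target, pred_wrapper, nsim = 5) {
  features <- setdiff(names(train), target)
  baseline <- rmse_strict(train[[target]], pred_wrapper(object, train[features]))
  sapply(features, function(feature) {
    increases <- replicate(nsim, {
      shuffled <- train
      shuffled[[feature]] <- sample(shuffled[[feature]])
      rmse_strict(train[[target]], pred_wrapper(object, shuffled[features])) - baseline
    })
    mean(increases)
  })
}

set.seed(7)
train_small <- mtcars[c("mpg", "wt", "hp", "qsec")]
importance <- permute_importance(lm_fit, train_small, "mpg", pred_vector)
print(sort(importance, decreasing = TRUE))
```

Passing pred_matrix instead of pred_vector to permute_importance stops with the matrix error, which is the same failure mode as using kernlab::predict directly as the pred_wrapper.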