An SVM model does not carry any built-in information about variable importance, so I am permuting the variables to generate a variable importance plot, following this answer:
library(vip)
library(MASS)
library(tidymodels)

# Boston data set from the MASS package
data(Boston, package = "MASS")
df <- Boston

# Seeded train/test partition
set.seed(7)
splits <- initial_split(df)
train <- training(splits)
test <- testing(splits)

# Recipe: model medv on every other column, normalizing all predictors
rec <- step_normalize(
  recipe(medv ~ ., data = train),
  all_predictors()
)

# RBF-kernel SVM specification, kernlab engine, regression mode
svm_spec <- set_mode(
  set_engine(
    svm_rbf(margin = 0.0937, cost = 26.7, rbf_sigma = 0.0208),
    "kernlab"
  ),
  "regression"
)

# Bundle the recipe and the model into a workflow, then fit on the training set
svr_wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(svm_spec)
svr_fit <- fit(svr_wflow, data = train)

# Permutation-based variable importance plot.
# NOTE: pull_workflow_fit() is deprecated, and this call is the one that
# raises the error reported below.
fitted_model <- pull_workflow_fit(svr_fit)
vip(
  fitted_model,
  method = "permute",
  nsim = 5,
  target = "medv",
  metric = "rmse",
  pred_wrapper = kernlab::predict,
  train = train
)
But it returns the following error:
Error in `metric_fun()`:
! `estimate` should be a numeric vector, not a numeric matrix.
At the same time, `pull_workflow_fit()` is deprecated. What should be used instead of `pull_workflow_fit()`?
TL;DR: `pred_wrapper` has to be updated to return a numeric vector, not a matrix.
The variables should not be normalized to get meaningful results.
`pull_workflow_fit()` is replaced with `extract_fit_parsnip()`.
library(vip)
#>
#> Attaching package: 'vip'
#> The following object is masked from 'package:utils':
#>
#> vi
library(MASS)
library(tidymodels)

# Boston data set from the MASS package
data(Boston, package = "MASS")
df <- Boston

# Seeded train/test partition
set.seed(7)
splits <- initial_split(df)
train <- training(splits)
test <- testing(splits)

# Recipe without any preprocessing steps: medv is modelled on all other
# columns, which are left un-normalized
rec <- recipe(formula = medv ~ ., data = train)

# RBF-kernel SVM specification, kernlab engine, regression mode
svm_spec <- set_mode(
  set_engine(
    svm_rbf(margin = 0.0937, cost = 20, rbf_sigma = 0.0208),
    "kernlab"
  ),
  "regression"
)

# Bundle the recipe and the model into a workflow, then fit on the training set
svr_wflow <- workflow() %>%
  add_recipe(rec) %>%
  add_model(svm_spec)
svr_fit <- fit(svr_wflow, data = train)

# Prediction wrapper for vip(): flattens the result of kernlab::predict()
# with as.vector() so the metric receives a plain vector
predict_as_vector <- function(object, newdata) {
  as.vector(kernlab::predict(object, newdata))
}

# Permutation-based variable importance, drawn as a point plot
fitted_model <- extract_fit_parsnip(svr_fit)
vip(
  fitted_model,
  method = "permute",
  nsim = 5,
  target = "medv",
  metric = "rmse",
  geom = "point",
  pred_wrapper = predict_as_vector,
  train = train
)
Created on 2023-12-30 with reprex v2.0.2