I followed this vignette from the tidymodels package to make gbm an engine for the boost_tree() function. I came up with the following for regression tasks:
library(tidyverse)
library(tidymodels)
library(workflows)
library(tune)
install.packages("gbm") # just install, do not load
{
set_model_engine(
model = "boost_tree",
mode = "regression",
eng = "gbm"
)
set_dependency("boost_tree", eng = "gbm", pkg = "gbm")
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "trees",
original = "n.trees",
func = list(pkg = "gbm", fun = "gbm"),
has_submodel = TRUE
)
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "tree_depth",
original = "interaction.depth",
func = list(pkg = "gbm", fun = "gbm"),
has_submodel = TRUE
)
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "learn_rate",
original = "shrinkage",
func = list(pkg = "gbm", fun = "gbm"),
has_submodel = TRUE
)
set_encoding(
model = "boost_tree",
eng = "gbm",
mode = "regression",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)
gbm <- function(mode = "regression", trees = NULL,
tree_depth = NULL, learn_rate = NULL, cv.folds = 1) {
# make sure mode is regression
if(mode != "regression") {
stop("`mode` should be 'regression'", call. = FALSE)
}
# capture arguments in quosures
args <- list(
trees = rlang::enquo(trees),
tree_depth = rlang::enquo(tree_depth),
learn_rate = rlang::enquo(learn_rate)
)
# empty slots for future parts of specification
out <- list(args = args, eng_args = NULL,
mode = mode, method = NULL, engine = NULL)
# set class in correct order
class(out) <- make_classes("gbm")
out
}
set_fit(
model = "boost_tree",
eng = "gbm",
mode = "regression",
value = list(
interface = "formula", # other possible values are "data.frame", "matrix"
protect = c("formula", "data"), # nonchangeable user-arguments
func = c(pkg = "gbm", fun = "gbm"), # call gbm::gbm()
defaults = list(
distribution = "gaussian",
n.cores = NULL,
verbose = FALSE
) # default argument changeable by user
)
)
set_pred(
model = "boost_tree",
eng = "gbm",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args = list(
object = expr(object$fit),
newdata = expr(new_data),
n.trees = expr(object$fit$n.trees),
type = "response",
single.tree = TRUE
)
)
)
}
But if I try to use this engine to tune the hyperparameters with tune_bayes() from the tune package, my code fails to extract the parameter set from the workflow:
rec <- recipe(mpg ~ ., mtcars)
model_tune <- parsnip::boost_tree(
  mode = "regression",
  trees = 1000,
  tree_depth = tune(),
  learn_rate = tune()
)
model_wflow <- workflow() %>%
add_model(model_tune) %>%
add_recipe(rec)
HP_set <- extract_parameter_set_dials(model_wflow, tree_depth(range = c(1, 100)))
HP_set
The function extract_parameter_set_dials() always throws the following error:
Error in `mutate()`:
! Problem while computing `object = purrr::map(call_info, eval_call_info)`.
Caused by error in `.f()`:
! Error when calling gbm(): Error in terms.formula(formula, data = data) :
argument is not a valid model
Maybe this has something to do with the set_fit() options in the engine settings, but that is just a wild guess.
How can I use the gbm engine for boost_tree() and tune the hyperparameters with tune_bayes()?
You were really close, but there were a couple of issues:

- The set_model_arg() calls should reference the tuning functions in the dials package, not gbm::gbm(). extract_parameter_set_dials() evaluates each func entry to build the parameter set (that is the purrr::map(call_info, eval_call_info) step in your traceback), so pointing func at gbm::gbm() ends up calling gbm() without a formula and fails.
- has_submodel should be FALSE for these arguments.

library(tidymodels)
set_model_engine(model = "boost_tree",
mode = "regression",
eng = "gbm")
set_dependency("boost_tree", eng = "gbm", pkg = "gbm")
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "trees",
original = "n.trees",
func = list(pkg = "dials", fun = "trees"), # <- change here
has_submodel = FALSE # <- change here too
)
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "tree_depth",
original = "interaction.depth",
func = list(pkg = "dials", fun = "tree_depth"), # <- change here
has_submodel = FALSE # <- change here too
)
set_model_arg(
model = "boost_tree",
eng = "gbm",
parsnip = "learn_rate",
original = "shrinkage",
func = list(pkg = "dials", fun = "learn_rate"), # <- change here
has_submodel = FALSE # <- change here too
)
set_encoding(
model = "boost_tree",
eng = "gbm",
mode = "regression",
options = list(
predictor_indicators = "none",
compute_intercept = FALSE,
remove_intercept = FALSE,
allow_sparse_x = FALSE
)
)
set_fit(
model = "boost_tree",
eng = "gbm",
mode = "regression",
value = list(
interface = "formula",
# other possible values are "data.frame", "matrix"
protect = c("formula", "data"),
# nonchangeable user-arguments
func = c(pkg = "gbm", fun = "gbm"),
# call gbm::gbm()
defaults = list(
distribution = "gaussian",
n.cores = NULL,
verbose = FALSE
) # default argument changeable by user
)
)
set_pred(
model = "boost_tree",
eng = "gbm",
mode = "regression",
type = "numeric",
value = list(
pre = NULL,
post = NULL,
func = c(fun = "predict"),
args = list(
object = expr(object$fit),
newdata = expr(new_data),
n.trees = expr(object$fit$n.trees),
type = "response",
single.tree = TRUE
)
)
)
model_spec <- parsnip::boost_tree(
mode = "regression",
trees = 1000,
tree_depth = tune(),
learn_rate = tune()
) %>%
set_engine("gbm")
data(Sacramento)
model_wflow <- workflow(price ~ beds + baths + sqft, model_spec)
extract_parameter_set_dials(model_wflow, tree_depth(range = c(1, 100)))
#> Collection of 2 parameters for tuning
#>
#> identifier type object
#> tree_depth tree_depth nparam[+]
#> learn_rate learn_rate nparam[+]
tune_bayes(
model_wflow,
resamples = bootstraps(Sacramento, times = 5),
iter = 3
)
#> # Tuning results
#> # Bootstrap sampling
#> # A tibble: 20 × 5
#> splits id .metrics .notes .iter
#> <list> <chr> <list> <list> <int>
#> 1 <split [932/341]> Bootstrap1 <tibble [10 × 6]> <tibble [0 × 3]> 0
#> 2 <split [932/348]> Bootstrap2 <tibble [10 × 6]> <tibble [0 × 3]> 0
#> 3 <split [932/336]> Bootstrap3 <tibble [10 × 6]> <tibble [0 × 3]> 0
#> 4 <split [932/348]> Bootstrap4 <tibble [10 × 6]> <tibble [0 × 3]> 0
#> 5 <split [932/359]> Bootstrap5 <tibble [10 × 6]> <tibble [0 × 3]> 0
#> 6 <split [932/341]> Bootstrap1 <tibble [2 × 6]> <tibble [0 × 3]> 1
#> 7 <split [932/348]> Bootstrap2 <tibble [2 × 6]> <tibble [0 × 3]> 1
#> 8 <split [932/336]> Bootstrap3 <tibble [2 × 6]> <tibble [0 × 3]> 1
#> 9 <split [932/348]> Bootstrap4 <tibble [2 × 6]> <tibble [0 × 3]> 1
#> 10 <split [932/359]> Bootstrap5 <tibble [2 × 6]> <tibble [0 × 3]> 1
#> 11 <split [932/341]> Bootstrap1 <tibble [2 × 6]> <tibble [0 × 3]> 2
#> 12 <split [932/348]> Bootstrap2 <tibble [2 × 6]> <tibble [0 × 3]> 2
#> 13 <split [932/336]> Bootstrap3 <tibble [2 × 6]> <tibble [0 × 3]> 2
#> 14 <split [932/348]> Bootstrap4 <tibble [2 × 6]> <tibble [0 × 3]> 2
#> 15 <split [932/359]> Bootstrap5 <tibble [2 × 6]> <tibble [0 × 3]> 2
#> 16 <split [932/341]> Bootstrap1 <tibble [2 × 6]> <tibble [0 × 3]> 3
#> 17 <split [932/348]> Bootstrap2 <tibble [2 × 6]> <tibble [0 × 3]> 3
#> 18 <split [932/336]> Bootstrap3 <tibble [2 × 6]> <tibble [0 × 3]> 3
#> 19 <split [932/348]> Bootstrap4 <tibble [2 × 6]> <tibble [0 × 3]> 3
#> 20 <split [932/359]> Bootstrap5 <tibble [2 × 6]> <tibble [0 × 3]> 3
Created on 2022-05-03 by the reprex package (v2.0.1)
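From here you can pick the best hyperparameters and refit on the full data set. A minimal sketch, assuming you save the tune_bayes() results to a variable and optimize RMSE (the names res, best_params, final_wflow, and final_fit below are just illustrative):

res <- tune_bayes(
  model_wflow,
  resamples = bootstraps(Sacramento, times = 5),
  iter = 3
)

# pick the numerically best hyperparameters by RMSE
best_params <- select_best(res, metric = "rmse")

# plug them into the workflow and fit on the full data
final_wflow <- finalize_workflow(model_wflow, best_params)
final_fit <- fit(final_wflow, Sacramento)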