Tags: r, hyperparameters, tidymodels, gbm, r-parsnip

extract_parameter_set_dials() fails for self-written gbm engine for boost_tree()


I followed this tidymodels vignette to register gbm as an engine for the boost_tree() function.

I came up with the following for regression tasks:

library(tidyverse)
library(tidymodels)
library(workflows)
library(tune)
# gbm only has to be installed, not attached: the engine below reaches it
# through the `pkg = "gbm"` entries.
if (!requireNamespace("gbm", quietly = TRUE)) {
  install.packages("gbm")
}
{
    # Register "gbm" as a regression engine for parsnip::boost_tree().
    set_model_engine(
      model = "boost_tree",
      mode  = "regression",
      eng   = "gbm"
    )
    set_dependency("boost_tree", eng = "gbm", pkg = "gbm")

    # Map parsnip's argument names to gbm::gbm()'s argument names.
    # NOTE: these `func = list(pkg = "gbm", fun = "gbm")` entries are what
    # trigger the extract_parameter_set_dials() error shown further below.
    set_model_arg(
      model        = "boost_tree",
      eng          = "gbm",
      parsnip      = "trees",
      original     = "n.trees",
      func         = list(pkg = "gbm", fun = "gbm"),
      has_submodel = TRUE
    )
    set_model_arg(
      model        = "boost_tree",
      eng          = "gbm",
      parsnip      = "tree_depth",
      original     = "interaction.depth",
      func         = list(pkg = "gbm", fun = "gbm"),
      has_submodel = TRUE
    )
    set_model_arg(
      model        = "boost_tree",
      eng          = "gbm",
      parsnip      = "learn_rate",
      original     = "shrinkage",
      func         = list(pkg = "gbm", fun = "gbm"),
      has_submodel = TRUE
    )

    # Predictors are passed through unchanged: no indicator columns and no
    # intercept column are created; sparse matrices are not accepted.
    set_encoding(
      model = "boost_tree",
      eng = "gbm",
      mode = "regression",
      options = list(
        predictor_indicators = "none",
        compute_intercept = FALSE,
        remove_intercept = FALSE,
        allow_sparse_x = FALSE
      )
    )

    # Model-specification constructor written following the vignette.
    # Only `mode = "regression"` is supported. Note that `cv.folds` is
    # accepted but not stored in the returned specification.
    gbm <- function(mode = "regression", trees = NULL,
                    tree_depth = NULL, learn_rate = NULL, cv.folds = 1) {
      if (mode != "regression") {
        stop("`mode` should be 'regression'", call. = FALSE)
      }

      # Capture the arguments as quosures so tune() placeholders survive.
      args <- list(
        trees      = rlang::enquo(trees),
        tree_depth = rlang::enquo(tree_depth),
        learn_rate = rlang::enquo(learn_rate)
      )

      # Empty slots for the parts of the specification filled in later.
      out <- list(args = args, eng_args = NULL,
                  mode = mode, method = NULL, engine = NULL)

      class(out) <- make_classes("gbm")
      out
    }

    # Fitting calls gbm::gbm() through its formula interface.
    set_fit(
      model = "boost_tree",
      eng   = "gbm",
      mode  = "regression",
      value = list(
        interface = "formula",          # alternatives: "data.frame", "matrix"
        protect = c("formula", "data"), # arguments users may not override
        func = c(pkg = "gbm", fun = "gbm"),
        defaults = list(                # defaults users may override
          distribution = "gaussian",
          n.cores = NULL,
          verbose = FALSE
        )
      )
    )

    # Numeric predictions use all fitted trees. `single.tree = TRUE` would
    # return only the contribution of tree number `n.trees`, not the
    # ensemble prediction.
    set_pred(
      model = "boost_tree",
      eng   = "gbm",
      mode  = "regression",
      type  = "numeric",
      value = list(
        pre = NULL,
        post = NULL,
        func = c(fun = "predict"),
        args = list(
          object = expr(object$fit),
          newdata = expr(new_data),
          n.trees = expr(object$fit$n.trees),
          type = "response",
          single.tree = FALSE
        )
      )
    )
}

But when I try to use this engine to tune the hyperparameters with tune_bayes() from the tune package, my code fails to extract the parameter set from the workflow:

rec <- recipe(mpg ~ ., mtcars)

# The engine must be selected explicitly; otherwise boost_tree() uses its
# default engine (xgboost) instead of the "gbm" engine registered above.
model_tune <- parsnip::boost_tree(
        mode = "regression",
        trees = 1000,
        tree_depth = tune(),
        learn_rate = tune()
) %>%
  set_engine("gbm")

model_wflow <- workflow() %>%
  add_model(model_tune) %>%
  add_recipe(rec)

# For workflows, extract_parameter_set_dials() does not use extra arguments.
# A custom range therefore has to be applied to the returned set with update().
HP_set <- extract_parameter_set_dials(model_wflow) %>%
  update(tree_depth = tree_depth(range = c(1, 100)))
HP_set

extract_parameter_set_dials() always throws the following error:

Error in `mutate()`:
! Problem while computing `object = purrr::map(call_info, eval_call_info)`.
Caused by error in `.f()`:
! Error when calling gbm(): Error in terms.formula(formula, data = data) : 
  argument is not a valid model

Maybe this has something to do with the set_fit() options in the engine settings, but that is just a wild guess.

How can I use the gbm engine for boost_tree() and tune the hyperparameter with tune_bayes()?


Solution

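  • You were really close, but there were a couple of issues. The `func` element of `set_model_arg()` must name the **dials** function that creates the tuning-parameter object (e.g. `dials::trees()`), not the modeling function. `extract_parameter_set_dials()` calls that function to build each parameter, which is why it ended up calling `gbm()` with no arguments. In addition, `has_submodel` should be `FALSE` unless you also implement a `multi_predict()` method. Finally, the model specification needs `set_engine("gbm")` so that the new engine is actually used: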

    library(tidymodels)

    # Register "gbm" as a regression engine for parsnip::boost_tree().
    set_model_engine(model = "boost_tree",
                     mode  = "regression",
                     eng   = "gbm")
    set_dependency("boost_tree", eng = "gbm", pkg = "gbm")

    # Map each parsnip argument to its gbm::gbm() name. `func` names the
    # dials function that builds the matching tuning-parameter object; it is
    # what extract_parameter_set_dials() calls, so it must not be gbm::gbm().
    # `has_submodel = FALSE` because no multi_predict() method is defined
    # for this engine.
    set_model_arg(
        model        = "boost_tree",
        eng          = "gbm",
        parsnip      = "trees",
        original     = "n.trees",
        func         = list(pkg = "dials", fun = "trees"),  # <- change here
        has_submodel = FALSE
    )
    set_model_arg(
        model        = "boost_tree",
        eng          = "gbm",
        parsnip      = "tree_depth",
        original     = "interaction.depth",
        func         = list(pkg = "dials", fun = "tree_depth"), # <- change here
        has_submodel = FALSE
    )
    set_model_arg(
        model        = "boost_tree",
        eng          = "gbm",
        parsnip      = "learn_rate",
        original     = "shrinkage",
        func         = list(pkg = "dials", fun = "learn_rate"), # <- change here
        has_submodel = FALSE
    )

    # Predictors are passed through unchanged: no indicator columns and no
    # intercept column are created; sparse matrices are not accepted.
    set_encoding(
        model = "boost_tree",
        eng = "gbm",
        mode = "regression",
        options = list(
            predictor_indicators = "none",
            compute_intercept = FALSE,
            remove_intercept = FALSE,
            allow_sparse_x = FALSE
        )
    )

    # Fitting calls gbm::gbm() through its formula interface.
    set_fit(
        model = "boost_tree",
        eng   = "gbm",
        mode  = "regression",
        value = list(
            # Other possible values are "data.frame" and "matrix".
            interface = "formula",
            # Arguments the user may not override.
            protect = c("formula", "data"),
            # Call gbm::gbm().
            func = c(pkg = "gbm", fun = "gbm"),
            # Defaults the user may override via set_engine().
            defaults = list(
                distribution = "gaussian",
                n.cores = NULL,
                verbose = FALSE
            )
        )
    )

    # Numeric predictions use all fitted trees. `single.tree = TRUE` would
    # return only the contribution of tree number `n.trees`, not the
    # ensemble prediction.
    set_pred(
        model = "boost_tree",
        eng   = "gbm",
        mode  = "regression",
        type  = "numeric",
        value = list(
            pre = NULL,
            post = NULL,
            func = c(fun = "predict"),
            args = list(
                object = expr(object$fit),
                newdata = expr(new_data),
                n.trees = expr(object$fit$n.trees),
                type = "response",
                single.tree = FALSE
            )
        )
    )
    
    
    model_spec <- parsnip::boost_tree(
        mode = "regression",
        trees = 1000,
        tree_depth = tune(),
        learn_rate = tune()
    ) %>%
        set_engine("gbm")

    data(Sacramento)

    model_wflow <- workflow(price ~ beds + baths + sqft, model_spec)

    # For workflows, extract_parameter_set_dials() does not use `...`, so a
    # custom range has to be applied to the returned set with update().
    hp_set <- extract_parameter_set_dials(model_wflow) %>%
        update(tree_depth = tree_depth(range = c(1, 100)))
    hp_set
    #> Collection of 2 parameters for tuning
    #> 
    #>  identifier       type    object
    #>  tree_depth tree_depth nparam[+]
    #>  learn_rate learn_rate nparam[+]

    # Pass the customised set via `param_info`; without it, tune_bayes()
    # derives its own parameter set with the default dials ranges.
    tune_bayes(
        model_wflow,
        resamples = bootstraps(Sacramento, times = 5),
        param_info = hp_set,
        iter = 3
    )
    #> # Tuning results
    #> # Bootstrap sampling 
    #> # A tibble: 20 × 5
    #>    splits            id         .metrics          .notes           .iter
    #>    <list>            <chr>      <list>            <list>           <int>
    #>  1 <split [932/341]> Bootstrap1 <tibble [10 × 6]> <tibble [0 × 3]>     0
    #>  2 <split [932/348]> Bootstrap2 <tibble [10 × 6]> <tibble [0 × 3]>     0
    #>  3 <split [932/336]> Bootstrap3 <tibble [10 × 6]> <tibble [0 × 3]>     0
    #>  4 <split [932/348]> Bootstrap4 <tibble [10 × 6]> <tibble [0 × 3]>     0
    #>  5 <split [932/359]> Bootstrap5 <tibble [10 × 6]> <tibble [0 × 3]>     0
    #>  6 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     1
    #>  7 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     1
    #>  8 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     1
    #>  9 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     1
    #> 10 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     1
    #> 11 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     2
    #> 12 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     2
    #> 13 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     2
    #> 14 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     2
    #> 15 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     2
    #> 16 <split [932/341]> Bootstrap1 <tibble [2 × 6]>  <tibble [0 × 3]>     3
    #> 17 <split [932/348]> Bootstrap2 <tibble [2 × 6]>  <tibble [0 × 3]>     3
    #> 18 <split [932/336]> Bootstrap3 <tibble [2 × 6]>  <tibble [0 × 3]>     3
    #> 19 <split [932/348]> Bootstrap4 <tibble [2 × 6]>  <tibble [0 × 3]>     3
    #> 20 <split [932/359]> Bootstrap5 <tibble [2 × 6]>  <tibble [0 × 3]>     3
    

    Created on 2022-05-03 by the reprex package (v2.0.1)