
How can I pass an extra variable to a tidymodels fit function?


I am writing a tidymodels engine to fit the joint LASSO described here and implemented in the fuser package. This is a LASSO model for regression that allows partial sharing of information between groups, in this case tissues in a biomedical RNAseq experiment.

An important difference from the regular LASSO is that, in addition to a matrix of predictors X and a vector of outcomes Y, it also takes a vector of group indicators, which in my case is the tissue that each observation belongs to.

I would like to pass a column name or role to the groups argument of the fit function, so that in a resampling/cross-validation loop the groups argument corresponds to the grouping of the subset of data used in each resample.
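To make the requirement concrete: a resample is just a set of row indices, so whatever is passed as groups has to be subset with exactly the same indices as X and Y. The base-R sketch below uses made-up toy data (all object names are illustrative, not from fuser or tidymodels) and shows why a groups vector fixed when the model specification is created cannot line up with the rows inside a resampling loop.

```r
set.seed(42)

# Toy data standing in for the real predictors, outcome, and tissue labels
n <- 12
X <- matrix(rnorm(n * 3), nrow = n,
            dimnames = list(NULL, paste0("Feature", 1:3)))
Y <- rnorm(n)
tissue <- rep(c("a", "b"), length.out = n)

# One resample: a random subset of row indices (the "analysis" set)
analysis_rows <- sort(sample(n, size = 9))

# Everything the fitting function sees must be indexed by the same rows
X_fit <- X[analysis_rows, , drop = FALSE]
Y_fit <- Y[analysis_rows]
groups_fit <- tissue[analysis_rows]

# A groups vector captured once, for the full data, no longer lines up
length(tissue) == nrow(X_fit)      # FALSE: 12 labels for 9 rows
length(groups_fit) == nrow(X_fit)  # TRUE
```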

I can't find any information in the parsnip docs describing how to pass extra variables, selections, etc. to the fit function. If I add it to the engine definition as I've done below and then use the engine in a normal tidymodels pipeline, I get an error that groups is missing (expected, given I haven't specified how to pass it to the fitting function).

library("tidymodels")
library("here")
## the content of src/tidymodels-utils.R is at the bottom of this question,
## beginning with set_new_model
source(here("src/tidymodels-utils.R"))


nfeats <- 100
nsamples <- 120
ngroups <- 2
group <- sample(letters[1:ngroups], nsamples, replace=TRUE)
predictors <- matrix(rnorm(nfeats*nsamples), nrow = nsamples, ncol = nfeats,
    dimnames = list(paste("Sample", 1:nsamples), paste("Feature", 1:nfeats))
)
outcome <- rnorm(nsamples)

## tidymodels wants a dataframe input
input <- data.frame(
    predictors,
    group = group,
    outcome = outcome
)

norm_recipe <- recipe(input) %>%
    update_role(matches("Feature"), new_role = "predictor") %>%
    update_role(outcome, new_role = "outcome") %>%
    update_role(group, new_role = "group")
    ## some steps to center/scale/select predictors and outcome omitted here;
    ## no prep(): a workflow trains the recipe itself on each resample

glmnet_model <- joint_lasso(penalty = tune(), fusion = tune(),
        groups = matches("group")) %>% 
    set_engine("fusedLassoProximal")

workflow <- workflow() %>%
    add_recipe(norm_recipe) %>%
    add_model(glmnet_model)

nfolds <- 5    # illustrative values
nrepeats <- 1
folds <- vfold_cv(input, v = nfolds, repeats = nrepeats)

glmn_set <- parameters(
    penalty(),
    fusion()
)
glmn_grid <- grid_regular(glmn_set, levels = 3)
ctrl <- control_grid(save_pred = TRUE, verbose = TRUE)

results <- workflow %>%
    tune_grid(
        resamples = folds,
        metrics = metric_set(rmse),
        grid = glmn_grid,
        control = ctrl
    )
# Error in `mutate()`:
# ℹ In argument: `object = purrr::map(call_info, eval_call_info)`.
# Caused by error in `purrr::map()`:
# ℹ In index: 1.
# Caused by error in `.f()`:
# ! Error when calling fusedLassoProximal(): Error in unique(groups) : argument "groups" is missing, with no default
# Run `rlang::last_trace()` to see where the error occurred.

Any ideas where this will fit? There's a lot of boilerplate, but my journeys through the various ? pages haven't shown anything obvious. It's worth noting that the prediction method will also need this group information for the test data in each resample, because the return value of fuser::fusedLassoProximal is simply a matrix of coefficients, one column for each unique value of group.
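One way the prediction step could use that information, sketched in base R: assuming the fitted object is a p × k coefficient matrix whose row names are the features and whose column names are the group labels (the question only says there is one column per group; the naming and the absence of an intercept row are assumptions), each new observation is scored with the column that matches its own group. `predict_by_group()` is a hypothetical helper, not part of fuser or parsnip.

```r
# Hypothetical helper: score each row of new_x with the coefficient column
# belonging to that row's group.
predict_by_group <- function(beta, new_x, new_groups) {
  new_groups <- as.character(new_groups)
  unknown <- setdiff(unique(new_groups), colnames(beta))
  if (length(unknown) > 0) {
    stop("No coefficients for group(s): ", paste(unknown, collapse = ", "))
  }
  # Align the feature columns of new_x with the coefficient rows by name
  new_x <- new_x[, rownames(beta), drop = FALSE]
  # Row i of coef_rows holds the coefficients for observation i's group,
  # so an elementwise product followed by rowSums gives x_i . beta_g
  coef_rows <- t(beta[, new_groups, drop = FALSE])
  unname(rowSums(new_x * coef_rows))
}

beta <- matrix(c(1, 0, 2,    # group "a"
                 0, 1, -1),  # group "b"
               nrow = 3,
               dimnames = list(c("f1", "f2", "f3"), c("a", "b")))
new_x <- matrix(c(1, 2, 3,
                  1, 1, 1),
                nrow = 2, byrow = TRUE,
                dimnames = list(NULL, c("f1", "f2", "f3")))
predict_by_group(beta, new_x, c("a", "b"))
# [1] 7 0
```

The same helper would work inside a `predict_joint_lasso()` method, provided the group column of `new_data` reaches it alongside the predictors.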

The whole set of function calls to define the engine (and the fusion() penalty param) is as follows:

set_new_model("joint_lasso")
set_model_mode(model = "joint_lasso", mode = "regression")
set_model_engine(
    "joint_lasso", 
    mode = "regression", 
    eng = "fusedLassoProximal"
)
set_dependency("joint_lasso",
    eng = "fusedLassoProximal", pkg = "fuser"
)


set_model_arg(
    model = "joint_lasso",
    eng = "fusedLassoProximal",
    parsnip = "penalty",
    original = "lambda",
    func = list(pkg = "fuser", fun = "fusedLassoProximal"),
    has_submodel = FALSE
)

set_model_arg(
    model = "joint_lasso",
    eng = "fusedLassoProximal",
    parsnip = "groups",
    original = "groups",
    func = list(pkg = "fuser", fun = "fusedLassoProximal"),
    has_submodel = FALSE
)

joint_lasso <- function(
        mode = "regression", 
        penalty = NULL,
        fusion = NULL,
        groups = NULL) {

    if (mode  != "regression") {
        rlang::abort("`mode` should be 'regression'")
    }
    
    args <- list(
        groups = rlang::enquo(groups),
        penalty = rlang::enquo(penalty),
        fusion = rlang::enquo(fusion)
    )
    
    # Save some empty slots for future parts of the specification
    new_model_spec(
        "joint_lasso",
        args = args,
        eng_args = NULL,
        mode = mode,
        method = NULL,
        engine = NULL
    )
}

set_fit(
    model = "joint_lasso",
    eng = "fusedLassoProximal",
    mode = "regression",
    value = list(
        interface = "matrix",
        protect = c("X", "Y", "groups"),
        func = c(pkg = "fuser", fun = "fusedLassoProximal"),
        defaults = list()
    )
)

set_encoding(
    model = "joint_lasso",
    eng = "fusedLassoProximal",
    mode = "regression",
    options = list(
        predictor_indicators = "traditional",
        compute_intercept = TRUE,
        remove_intercept = TRUE,
        allow_sparse_x = FALSE
    )
)

predict_joint_lasso <- function(object, new_data=NULL) {
    ## some code here to use the new data
}

pred_info <- list(
    pre = NULL,
    post = NULL,
    func = c(fun = "predict_joint_lasso"),
    args =
        list(
            object = quote(object$fit),
            new_data = quote(new_data)
        )
)

set_pred(
    model = "joint_lasso",
    eng = "fusedLassoProximal",
    mode = "regression",
    type = "numeric",
    value = pred_info
)

show_model_info("joint_lasso")

fusion <- function(range=c(0, 2), trans = NULL) {
    new_quant_param(
        type = "double",
        range = range,
        inclusive = c(TRUE, TRUE),
        trans = trans,
        label = c(fusion = "Fusion parameter"),
        finalize = NULL
    )
}
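A possibly relevant detail about the registration code above: in parsnip's own engine definitions, the `func` element of `set_model_arg()` names the function that creates the tuning parameter (for example `list(pkg = "dials", fun = "penalty")` for glmnet's penalty), not the model-fitting function. The traceback above passes through `eval_call_info`, which is where that registered function gets called, so pointing `func` at `fuser::fusedLassoProximal` may be what raises the "groups is missing" error during `tune_grid()`, before any model is fitted. Note also that `fusion` is never registered with `set_model_arg()`. The sketch below shows registrations along those lines; mapping `fusion` to fuser's `gamma` argument, and pointing `func` at the user-defined `fusion()` function with no package, are assumptions.

```r
library(parsnip)

# Minimal model registration (same calls as in the question) so the
# arguments below have a model to attach to.
set_new_model("joint_lasso")
set_model_mode(model = "joint_lasso", mode = "regression")
set_model_engine("joint_lasso", mode = "regression", eng = "fusedLassoProximal")
set_dependency("joint_lasso", eng = "fusedLassoProximal", pkg = "fuser")

# `func` points at the parameter-generating function used when tuning,
# not at fuser::fusedLassoProximal().
set_model_arg(
  model = "joint_lasso",
  eng = "fusedLassoProximal",
  parsnip = "penalty",
  original = "lambda",
  func = list(pkg = "dials", fun = "penalty"),
  has_submodel = FALSE
)

# Assumption: fuser's fusion strength is its `gamma` argument, and the
# question's own fusion() parameter function is visible when tuning runs.
set_model_arg(
  model = "joint_lasso",
  eng = "fusedLassoProximal",
  parsnip = "fusion",
  original = "gamma",
  func = list(fun = "fusion"),
  has_submodel = FALSE
)

show_model_info("joint_lasso")
```

This does not solve the groups problem by itself; it only separates the parameter-registration error from the question of how groups reaches the fit.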

Solution

  • I think that you should write a wrapper function to provide a different interface for that function. Your point about resampling is spot-on and you need a way to point fuser's functions to specific columns in the data.

    We had similar issues with the gee function, which also has an argument for a length-n vector of groups. In that case, we used the little-documented "specials" feature of base R formulas.

    Here is a gist that has an implementation if you want to use it. Basically, the user would supply a formula that contains a groups() term, and the variable inside groups() defines the groups. It's all there.

    One other thing (addressed in the gist): the G argument is data-dependent. If a user uses a recipe or some tool that might reduce features, then the dimensions might be off. There is a function that will take the G argument as an expression and then evaluate it with the number of columns that are actually in the data. There are other (and probably better) approaches, but I thought I'd point that out.

    Feel free to put in a parsnip R if there is more that we can help with.
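Combining the specials idea and the data-dependent G from the answer, a wrapper might look roughly like the base-R sketch below. It is not the code from the gist; `prepare_joint_lasso()` is a hypothetical helper that pulls the groups column out of a formula containing `groups()`, builds X and Y from the remaining terms, and evaluates a user-supplied G expression only once the number of groups in the current data is known. It assumes a two-sided formula, that `groups()` appears as a main effect, and that G is a k × k matrix over the k groups (as in fuser's README examples, e.g. `matrix(1, k, k)`). The returned pieces would then be passed to `fuser::fusedLassoProximal()`.

```r
# Throwaway marker that only tags the grouping variable inside a formula.
# (It masks dplyr::groups() if dplyr is attached; rename it if that matters.)
groups <- function(x) x

# Hypothetical helper (not fuser or parsnip API): turn
#   outcome ~ x1 + x2 + groups(tissue)
# into the X, Y, groups and G inputs that a joint-LASSO fit needs.
prepare_joint_lasso <- function(formula, data, G = quote(matrix(1, k, k))) {
  tt <- terms(formula, specials = "groups")
  if (attr(tt, "response") != 1) {
    stop("The formula needs an outcome on the left-hand side.")
  }
  grp_pos <- attr(tt, "specials")$groups  # position among model-frame variables
  if (length(grp_pos) != 1) {
    stop("The formula must contain exactly one groups() term.")
  }

  mf <- model.frame(tt, data = data)
  Y <- model.response(mf, type = "numeric")
  grp <- mf[[grp_pos]]

  # Remove the groups() term from the right-hand side before building X
  grp_term <- match(names(mf)[grp_pos], attr(tt, "term.labels"))
  if (is.na(grp_term)) {
    stop("groups() must appear as a main effect.")
  }
  x_terms <- drop.terms(tt, dropx = grp_term, keep.response = FALSE)
  X <- model.matrix(x_terms, data = mf)
  # Assumption: the fitting function handles the intercept itself
  X <- X[, colnames(X) != "(Intercept)", drop = FALSE]

  # G is evaluated only now, with k bound to the number of groups in this data
  k <- length(unique(grp))
  G_mat <- eval(G, envir = list(k = k), enclos = parent.frame())

  list(X = X, Y = unname(Y), groups = grp, G = G_mat)
}

dat <- data.frame(
  outcome = c(1.2, 0.4, 2.2, 1.9, 0.7, 1.1),
  x1 = c(0.1, 0.5, 0.3, 0.9, 0.2, 0.6),
  x2 = c(1.0, 0.2, 0.8, 0.4, 0.5, 0.3),
  tissue = c("liver", "liver", "brain", "brain", "heart", "heart")
)

full <- prepare_joint_lasso(outcome ~ x1 + x2 + groups(tissue), dat)
dim(full$G)
# [1] 3 3

# A resample that happens to contain only two tissues gets a 2 x 2 G
sub <- prepare_joint_lasso(outcome ~ x1 + x2 + groups(tissue), dat[1:4, ])
dim(sub$G)
# [1] 2 2
```

Because everything is derived from the data passed in, the same call gives consistent X, Y, groups, and G for whichever rows a resample supplies.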