I want to use mlr3 for cross-fitting of nuisance parameters in a semi-parametric estimator such as TMLE or AIPW. The cross-fitting procedure is similar to k-fold cross-validation: split the data into K folds of roughly equal size, then obtain predictions for each fold from a model trained on the remaining folds. However, with cross-fitting, I'm not interested in model evaluation. Instead, I need to reuse the K models to produce out-of-sample predictions, which relaxes certain assumptions necessary for valid statistical inference with machine learning estimators.
I'd like to use resample() from mlr3 for this.
require(mlr3verse)
# Create some data
set.seed(5434)
n <- 250
W <- matrix(rnorm(n*3), ncol=3)
A <- rbinom(n,1, 1/(1+exp(-(.2*W[,1] - .1*W[,2] + .4*W[,3]))))
Y <- A + 2*W[,1] + W[,3] + W[,2]^2 + rnorm(n)
dat <- data.frame(W, A, Y)
# Creating a Task with 2 pre-defined folds
K <- 2
folds <- sample(rep(1:K, length.out = n), size = n, replace = FALSE)
dat[, "fold_id"] <- folds
task <- as_task_regr(dat, "Y", "foo_task")
task$col_roles$group <- "fold_id"
task$col_roles$feature <- setdiff(task$col_roles$feature, "fold_id")
# Create a light gbm learner object
learn_gbm <- lrn("regr.lightgbm")
# Repeatedly train the learner K times and store the models
cv <- rsmp("cv", folds = K)
rr <- resample(task, learn_gbm, cv, store_models = TRUE)
From here, I'd like to use the stored models to predict on a modified version of dat (i.e., with A set to 1), restricted to the rows of each of the K test sets:
# Creating a copy of the dat where A is always 1
# Want to obtain out-of-sample predictions of Y on this data, dat_1
dat_1 <- dat
dat_1$A <- 1
# Using the first fold as an example
predict(rr$learners[[1]], newdata = dat_1[rr$resampling$test_set(1), ])
It seems like I can't use the stored models to predict on new data and I get this error:
Error: No task stored, and no task provided
How can I get these predictions with resample()?
Platform: aarch64-apple-darwin20 (64-bit)
Running under: macOS Monterey 12.4
Matrix products: default
LAPACK: /Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics grDevices utils datasets methods base
other attached packages:
[1] mlr3verse_0.2.5 mlr3_0.14.0
loaded via a namespace (and not attached):
[1] tidyselect_1.1.2 clusterCrit_1.2.8 purrr_0.3.4
[4] listenv_0.8.0 lattice_0.20-45 mlr3cluster_0.1.4
[7] colorspace_2.0-3 vctrs_0.4.1 generics_0.1.3
[10] bbotk_0.5.4 paradox_0.10.0 utf8_1.2.2
[13] rlang_1.0.4 pillar_1.8.0 glue_1.6.2
[16] withr_2.5.0 DBI_1.1.3 palmerpenguins_0.1.1
[19] uuid_1.1-0 prompt_1.0.1 mlr3fselect_0.7.2
[22] lifecycle_1.0.1 mlr3learners_0.5.4 munsell_0.5.0
[25] gtable_0.3.0 progressr_0.10.1 future_1.27.0
[28] codetools_0.2-18 mlr3data_0.6.1 parallel_4.2.1
[31] fansi_1.0.3 mlr3tuningspaces_0.3.0 scales_1.2.0
[34] backports_1.4.1 checkmate_2.1.0 mlr3filters_0.5.0
[37] mlr3viz_0.5.10 mlr3tuning_0.14.0 jsonlite_1.8.0
[40] lightgbm_3.3.2 parallelly_1.32.1 ggplot2_3.3.6
[43] digest_0.6.29 dplyr_1.0.9 mlr3extralearners_0.5.46-9000
[46] grid_4.2.1 clue_0.3-61 cli_3.3.0
[49] tools_4.2.1 magrittr_2.0.3 tibble_3.1.7
[52] cluster_2.1.3 mlr3misc_0.10.0 future.apply_1.9.0
[55] crayon_1.5.1 pkgconfig_2.0.3 Matrix_1.4-1
[58] ellipsis_0.3.2 data.table_1.14.2 mlr3pipelines_0.4.1
[61] assertthat_0.2.1 rstudioapi_0.13 lgr_0.4.3
[64] R6_2.5.1 globals_0.16.1 compiler_4.2.1
Reading https://mlr3.mlr-org.com/reference/Learner.html#method-predict-newdata- more closely, I found:
If the learner has been fitted via resample() or benchmark(), you need to pass the corresponding task stored in the ResampleResult or BenchmarkResult, respectively.
Trying,
rr$learners[[1]]$predict_newdata(dat_1[rr$resampling$test_set(1), ], task = rr$task)
produced,
<PredictionRegr> for 125 observations:
row_ids truth response
1 6.360351 5.4017536
2 1.104371 2.1685862
3 1.445641 1.6610209
---
123 5.533508 2.0075193
124 1.927375 4.5808117
125 -1.634550 0.7129312
This looks like what I needed.
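To get the cross-fitted predictions for every observation rather than just the first fold, the same call can be repeated over all K stored learners. Below is a minimal, self-contained sketch of that loop. It makes a few assumptions that are not in my original code: it uses regr.rpart (which ships with mlr3) as a stand-in for regr.lightgbm so that it runs without extra packages, it lets rsmp("cv") draw the folds instead of using the pre-defined fold_id grouping, and pred_1 is just a name I picked for the output vector.

```r
library(mlr3)

# Simulate data as in the question
set.seed(5434)
n <- 250
W <- matrix(rnorm(n * 3), ncol = 3)
A <- rbinom(n, 1, 1 / (1 + exp(-(0.2 * W[, 1] - 0.1 * W[, 2] + 0.4 * W[, 3]))))
Y <- A + 2 * W[, 1] + W[, 3] + W[, 2]^2 + rnorm(n)
dat <- data.frame(W, A, Y)

# Assumption: folds are drawn by rsmp("cv") here, not taken from a fold_id column
K <- 2
task <- as_task_regr(dat, target = "Y", id = "foo_task")

# Assumption: regr.rpart stands in for regr.lightgbm
rr <- resample(task, lrn("regr.rpart"), rsmp("cv", folds = K), store_models = TRUE)

# Counterfactual copy of the data where A is always 1
dat_1 <- dat
dat_1$A <- 1

# For each fold, predict on the held-out rows with the model that never saw them
pred_1 <- rep(NA_real_, n)
for (k in seq_len(K)) {
  test_rows <- rr$resampling$test_set(k)
  p <- rr$learners[[k]]$predict_newdata(dat_1[test_rows, ], task = rr$task)
  # Row ids in p refer to positions in the subset, so map back via test_rows
  pred_1[test_rows] <- p$response
}
```

Since the test sets of a CV resampling partition the rows, pred_1 should end up with exactly one out-of-sample prediction per observation. Note that the truth column in each prediction is still the observed Y, so it should not be read as a counterfactual outcome.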