mlr3

How to extract predictions (e.g. survival probability) for the outer-loop samples during nested CV (mlr3 benchmark)?


I would appreciate some guidance on how, during mlr3 benchmarking procedures, to extract (1) distributional predictions (e.g. the event probability at a particular time point) and linear predictors from the outer loop of a nested CV, and (2) the baseline hazard from the inner loop of a nested CV (where the model is developed), in order to calculate calibration indexes.

Using the lung dataset as an example:

rm(list = ls(all = TRUE))
library(reprex)
library(mlr3)
library(mlr3learners)
library(mlr3extralearners)
library(mlr3pipelines)
library(mlr3tuning)
#> Loading required package: paradox
library(mlr3proba)
library(data.table)

# set up data
task_lung = tsk('lung')
d = task_lung$data()
d$time = ceiling(d$time/30.44)
task_lung = as_task_surv(d, time = 'time', event = 'status', id = 'lung')
po_encode = po('encode', method = 'treatment')
po_impute = po('imputelearner', lrn('regr.rpart'))
pre = po_encode %>>% po_impute
task = pre$train(task_lung)[[1]]


# learners
cph=lrn("surv.coxph")
# get baseline hazard estimates
comp.cph = as_learner(ppl(  
  "distrcompositor",
  learner = cph,
  estimator = "kaplan",
  form = "ph"
))



# Benchmark the 2 learners above (outer 4-fold CV)
set.seed(123)
BM1 = benchmark(benchmark_grid(task,
                      list(cph, comp.cph),
                      rsmp('cv', folds=4)),
                store_models =T)
#> INFO  [13:00:22.564] [mlr3] Running benchmark with 8 resampling iterations
#> INFO  [13:00:22.649] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 1/4)
#> INFO  [13:00:22.716] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 2/4)
#> INFO  [13:00:22.766] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 3/4)
#> INFO  [13:00:22.825] [mlr3] Applying learner 'surv.coxph' on task 'lung' (iter 4/4)
#> INFO  [13:00:22.873] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 1/4)
#> INFO  [13:00:23.029] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 2/4)
#> INFO  [13:00:23.225] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 3/4)
#> INFO  [13:00:23.365] [mlr3] Applying learner 'distrcompositor.kaplan.surv.coxph.distrcompose' on task 'lung' (iter 4/4)
#> INFO  [13:00:23.510] [mlr3] Finished benchmark

# extract (inner) models
mdl=mlr3misc::map(as.data.table(BM1)$learner, "model")

# check that the linear predictors come from the inner (training) folds
# the following has 171 obs, as expected for 3 of 4 folds (3/4 * 228 = 171)
mdl[[1]][["linear.predictors"]]
#>   [1]  1.18499215  0.31858695  0.99802300  0.58241827  0.15774088 -0.12313273
#>   [7]  0.27176230  0.31255547  0.87977282  1.66127703  1.27052489  0.32452650
#>  [13] -0.39147459  0.18128928  0.40221932  0.32179235  1.29488433  0.22421192
#>  [19]  0.36696516  0.02351139  0.03435003  0.30268183  0.46156831  0.67854111
#>  [25] -0.38466200  0.18029125 -0.30794694  1.81985726  0.72143148  0.73993999
#>  [31] -0.19820495 -0.61578221  1.41832763  0.61361444  0.26284730  1.59711534
#>  [37]  0.08964261 -0.39070785  0.34200863  0.63126957  1.02630566  0.01239987
#>  [43]  0.23198859 -0.26151613 -0.39089134  0.19280407 -0.87206586 -0.62197655
#>  [49]  0.88013192  0.08972523 -0.74319591 -1.03042024 -0.27852335  0.08151983
#>  [55]  1.71881259 -0.67763599  1.11294971  0.88264769  0.10990293 -0.02842130
#>  [61]  0.37947992  0.49182045  0.77795647  0.71053557  0.65215179  0.27799555
#>  [67] -0.37213111  1.88280007  1.19325541 -0.12267881  1.15952023 -0.94928614
#>  [73]  0.86114411  0.31579099  1.91908641  0.77277223 -0.04938800  0.39655017
#>  [79] -0.02114651  1.77783288  1.01903263  0.67158262  0.49602602 -0.53871485
#>  [85]  1.13274466  0.24812544  1.61907952 -0.01141713  0.85803147  0.94607458
#>  [91] -0.41149162  0.35978972  1.53106530  1.31347748  0.12968634  0.95718654
#>  [97]  1.21293718  0.90307591 -0.10667084  0.87119681  0.94642450 -0.40357218
#> [103]  0.21266674  0.54770469  0.60351288  0.61024865 -1.15985267 -0.60745992
#> [109]  0.64926122  0.94255144  0.10625227 -0.17035191  0.35624501  0.78856855
#> [115]  0.82062314  0.64819105  0.29187039  0.83125774  1.40013302  0.64771200
#> [121]  1.15793176  0.62013293  0.13332034  0.11379916  0.53277392  0.67303587
#> [127]  0.07299108 -0.22845207 -0.37492090 -0.22763493  0.18024949 -0.45301382
#> [133] -0.11000587  0.10124783  0.11111948  0.43462859  1.19243908  0.87027351
#> [139]  1.32904450  0.14259790  0.39245370  0.62142923 -0.33833996  0.48248721
#> [145]  0.94495937  1.09431650  0.78709427  0.18538695 -0.26797215 -0.40705985
#> [151]  1.49430202  1.47283382 -1.12637695 -0.16080892 -0.35788332  0.57572753
#> [157] -0.85507993 -0.05270265 -0.11091513  0.60276879 -0.35633053  0.47495284
#> [163]  0.81178544  0.73027874 -0.43725719 -0.22938202 -0.28038346  0.54849467
#> [169]  0.06449779  1.38581779 -0.29679293

# extract baseline hazard 
# Assuming these come from the inner (training) folds, as required for distributional prediction
mdl[[5]][["distrcompositor.kaplan"]][["model"]][["std.chaz"]]
#>  [1] 0.01547223 0.02062259 0.02649866 0.03237961 0.03735763 0.04532140
#>  [7] 0.05320536 0.05999788 0.06527124 0.07576031 0.08515480 0.10010706
#> [13] 0.10877590 0.11161182 0.12888880 0.13672413 0.13672413 0.14575496
#> [19] 0.15721555 0.16497494 0.17407800 0.19293410 0.20412743 0.23869255
#> [25] 0.27111447 0.32365577 0.36404785 0.36404785 0.49360100 0.49360100


# Extract outer learners, based on a previous question ("mlr3, benchmarking and nested resampling: how to extract a tuned model from a benchmark object to calculate feature importance"), where code was suggested to extract the learners fitted in the outer loop: https://github.com/mlr-org/mlr3/issues/601

data = as.data.table(BM1)
outer_learners = mlr3misc::map(data$learner, "learner"); outer_learners  #outer_learners is null
#> [[1]]
#> NULL
#> 
#> [[2]]
#> NULL
#> 
#> [[3]]
#> NULL
#> 
#> [[4]]
#> NULL
#> 
#> [[5]]
#> NULL
#> 
#> [[6]]
#> NULL
#> 
#> [[7]]
#> NULL
#> 
#> [[8]]
#> NULL


# Based on https://github.com/mlr-org/mlr3/issues/601, it may be possible to save the best model from the inner CV and apply it for prediction on a new/outer dataset (this is not working either)
measures = list(
  msr('surv.cindex', id='cindex'),
  msr('surv.graf', id='brier'))
best_mdl <- BM1$score(measures)[learner_id=='encode.scale.surv.glmnet.tuned',][cindex==max(cindex)]$learners
#> Warning in max(cindex): no non-missing arguments to max; returning -Inf

Created on 2024-05-03 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Australia/Adelaide          
#>  date     2024-05-03                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version     date       lib
#>  backports           1.4.1       2021-12-13 [1]
#>  bbotk               0.7.3.9000  2023-11-22 [1]
#>  checkmate           2.3.1       2023-12-04 [1]
#>  cli                 3.6.2       2023-12-11 [1]
#>  codetools           0.2-18      2020-11-04 [2]
#>  colorspace          2.1-0       2023-01-23 [1]
#>  crayon              1.4.1       2021-02-08 [2]
#>  data.table        * 1.15.4      2024-03-30 [1]
#>  dictionar6          0.1.3       2021-09-13 [1]
#>  digest              0.6.35      2024-03-11 [1]
#>  distr6              1.8.4       2024-05-02 [1]
#>  dplyr               1.1.3       2023-09-03 [1]
#>  evaluate            0.23        2023-11-01 [1]
#>  fansi               1.0.6       2023-12-08 [1]
#>  fastmap             1.1.0       2021-01-25 [2]
#>  fs                  1.5.0       2020-07-31 [2]
#>  future              1.33.2      2024-03-26 [1]
#>  future.apply        1.11.2      2024-03-28 [1]
#>  generics            0.1.3       2022-07-05 [1]
#>  ggplot2             3.5.1       2024-04-23 [1]
#>  globals             0.16.3      2024-03-08 [1]
#>  glue                1.7.0       2024-01-09 [1]
#>  gtable              0.3.5       2024-04-22 [1]
#>  highr               0.9         2021-04-16 [2]
#>  htmltools           0.5.6       2023-08-10 [1]
#>  knitr               1.33        2021-04-24 [2]
#>  lattice             0.20-44     2021-05-02 [2]
#>  lgr                 0.4.4       2022-09-05 [1]
#>  lifecycle           1.0.4       2023-11-07 [1]
#>  listenv             0.9.1       2024-01-29 [1]
#>  magrittr            2.0.3       2022-03-30 [1]
#>  Matrix              1.3-4       2021-06-01 [2]
#>  mlr3              * 0.19.0      2024-04-24 [1]
#>  mlr3extralearners * 0.7.1       2023-11-24 [1]
#>  mlr3learners      * 0.5.7.9000  2023-11-24 [1]
#>  mlr3misc            0.15.0      2024-04-10 [1]
#>  mlr3pipelines     * 0.5.0-9000  2023-11-22 [1]
#>  mlr3proba         * 0.6.1       2024-05-02 [1]
#>  mlr3tuning        * 0.19.1.9000 2023-11-22 [1]
#>  mlr3viz             0.8.0       2024-03-05 [1]
#>  munsell             0.5.1       2024-04-01 [1]
#>  ooplah              0.2.0       2022-01-21 [1]
#>  palmerpenguins      0.1.1       2022-08-15 [1]
#>  paradox           * 0.11.1-9000 2023-11-22 [1]
#>  parallelly          1.37.1      2024-02-29 [1]
#>  param6              0.2.4       2023-11-22 [1]
#>  pillar              1.9.0       2023-03-22 [1]
#>  pkgconfig           2.0.3       2019-09-22 [2]
#>  R6                  2.5.1       2021-08-19 [1]
#>  Rcpp                1.0.12      2024-01-09 [1]
#>  reprex            * 2.0.1       2021-08-05 [1]
#>  RhpcBLASctl         0.23-42     2023-02-11 [1]
#>  rlang               1.1.3       2024-01-10 [1]
#>  rmarkdown           2.10        2021-08-06 [2]
#>  rpart               4.1-15      2019-04-12 [2]
#>  rstudioapi          0.15.0      2023-07-07 [1]
#>  scales              1.3.0       2023-11-28 [1]
#>  sessioninfo         1.1.1       2018-11-05 [2]
#>  set6                0.2.6       2023-11-22 [1]
#>  stringi             1.7.3       2021-07-16 [2]
#>  stringr             1.5.0       2022-12-02 [1]
#>  survival            3.5-7       2023-08-14 [1]
#>  survivalmodels      0.1.191     2024-03-19 [1]
#>  tibble              3.2.1       2023-03-20 [1]
#>  tidyselect          1.2.0       2022-10-10 [1]
#>  utf8                1.2.4       2023-10-22 [1]
#>  uuid                1.2-0       2024-01-14 [1]
#>  vctrs               0.6.5       2023-12-01 [1]
#>  withr               3.0.0       2024-01-16 [1]
#>  xfun                0.25        2021-08-06 [2]
#>  yaml                2.2.1       2020-02-01 [2]
#>  source                                    
#>  CRAN (R 4.1.0)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/distr6@a7c01f7)             
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3extralearners@6e2af9e)
#>  Github (mlr-org/mlr3learners@86f19eb)     
#>  CRAN (R 4.1.1)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  https://mlr-org.r-universe.dev (R 4.1.1)  
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/param6@0fa3577)             
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  Github (xoopR/set6@a901255)               
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#> 
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

I would very much appreciate any advice or suggestions on how to extract the relevant results from the inner loop and apply them for prediction on the outer-loop samples during nested-CV benchmarking procedures in mlr3.


Solution

  • Short answer

    ... to the original title question: you already have the outer-loop predictions; they are in the `data$prediction` slot.
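
    For instance, a minimal sketch of pulling those outer-fold predictions from the benchmark above (the `resample_result()`/`prediction()` calls are my assumption about the current mlr3 API for combining folds):

    data = as.data.table(BM1)
    data$prediction[[1]]                  # held-out (outer-fold) predictions, fold 1
    BM1$resample_result(1)$prediction()   # all 4 outer folds of one learner, combined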

    Details on the code/Suggestions

    # learners
    cph = lrn("surv.coxph")
    # get baseline hazard estimates
    

    Notes:

    1. with distrcompositor you compose S(t) (the survival prediction); you are not getting the baseline S_0(t). With a CoxPH model you don't need to compose an S(t): the one used by the survival package by default (the Breslow estimator) is good enough. If later you just need the trained Kaplan–Meier fit (a survfit object), you could simply use lrn("surv.kaplan").
    2. the argument overwrite is FALSE by default, so you are not adding anything by using the distrcompositor here; you are just making it harder to get at the learner inside the PipeOp afterwards (if you need that). A variant with overwrite = TRUE is sketched after the next snippet.
    comp.cph = as_learner(ppl(  
      "distrcompositor",
      learner = cph,
      estimator = "kaplan",
      form = "ph"
    ))
    comp.cph # see that `distrcompose.overwrite=FALSE`, so SAME S(t) as `cph` learner
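
    If you did want the composed S(t) to actually replace the Cox model's default one, a variant could look like the following (a sketch; it assumes the pipeline exposes an overwrite argument, as current mlr3proba does):

    # only with `overwrite = TRUE` does the compositor replace the Cox model's
    # own distribution prediction with the composed one
    comp.cph.ov = as_learner(ppl(
      "distrcompositor",
      learner = cph,
      estimator = "kaplan",
      form = "ph",
      overwrite = TRUE
    ))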
    
    lasso = as_learner(
      po("encode") %>>%
      po("scale") %>>%   # standardize = TRUE by default in glmnet so this is re-done
        # Note: good to write the arguments, eg `tuner = ...`
      auto_tuner(
        tuner = tnr("grid_search", resolution = 10, batch_size = 10),
        learner = lrn("surv.glmnet", alpha = 1, s = to_tune(0.005, 1)),
        resampling = rsmp("cv", folds = 2),
        measure = msr("surv.cindex"),
        store_models = TRUE,
        terminator = trm("stagnation", iters = 50, threshold = 0.01)
      )
    )
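
    Plugged into the benchmark from the question, this tuned learner gives the actual nested CV. A sketch under those assumptions (BM2 is just a new object name; with 3 learners x 4 outer folds you get 12 resampling iterations, which is why indices like mdl[[9]] and [[12]] appear below):

    set.seed(123)
    BM2 = benchmark(
      benchmark_grid(task, list(cph, comp.cph, lasso), rsmp("cv", folds = 4)),
      store_models = TRUE
    )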
    

    I assume the baseline hazards extracted above were also based on the inner sample (which is required for outer sample (distributional) prediction for cph/lasso)

    YES, since:

    mdl[[9]]$distrcompositor.kaplan$model$n  # 171 = #samples in 3 out of 4 folds
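
    And if you need the full baseline curve rather than just n, a sketch (assuming the stored Kaplan–Meier model is a survival::survfit object, which the std.chaz field in your reprex suggests):

    km_fit = mdl[[9]]$distrcompositor.kaplan$model
    data.frame(time   = km_fit$time,
               surv   = km_fit$surv,     # survival curve used as the compositor's baseline
               cumhaz = km_fit$cumhaz)   # corresponding cumulative hazard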
    

    the suggested code to extract learners fitted in the outer loop is not working:

    It was a different example there; I get a different error when running it:

    data = as.data.table(BM1)
    outer_learners = mlr3misc::map(data$learner, "learner")
    Error in `[[.R6`(X[[i]], ...) :
    R6 class LearnerSurvCoxPH/LearnerSurv/Learner/R6 does not have slot 'learner'!
    

    But note that data$learner is exactly the list of learners you are asking for, i.e. the tuned/trained versions (fitted on the 3 out of 4 folds), ready to be used for making predictions, as sketched below.
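
    A sketch of re-predicting with one of those outer-fold learners on its own held-out rows (column names as returned by as.data.table(BenchmarkResult)):

    i = 1
    data$learner[[i]]$predict(
      data$task[[i]],
      row_ids = data$resampling[[i]]$test_set(data$iteration[[i]])
    )  # same information as data$prediction[[i]]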

    ...it may be possible to save the best model from the inner CV (that is data$learner) and apply this model for prediction on a new/outer dataset:

    measures = list(msr('surv.cindex', id='cindex'), msr('surv.graf', id='brier'))
    BM1$score(measures)  # this works for me with no problem; updating your packages should fix it
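
    Note that the warning in your reprex came from filtering on a learner_id ('encode.scale.surv.glmnet.tuned') that is not in that benchmark grid, so no rows matched. With an id that exists you can pick the best outer-fold learner directly; a sketch (assuming score() keeps the learner list-column, as in your mlr3 version):

    scores = BM1$score(measures)
    best = scores[learner_id == "surv.coxph"][which.max(cindex)]
    best$learner[[1]]   # trained learner from the best outer fold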
    

    I would very much appreciate any hints/suggestions on how to extract the relevant results on the inner loop and apply these results for prediction tasks on the outer loop samples from nested-cv benchmarking procedures in mlr3.

    The outer-loop sample predictions are exactly the following (#samples = 57 = 228/4, i.e. one outer fold):

    data$prediction
    ...
    
    [[12]]
    <PredictionSurv> for 57 observations:
        row_ids time status       crank          lp     distr
             18   24   TRUE  0.82866010  0.82866010 <list[1]>
             21   10   TRUE  0.72062419  0.72062419 <list[1]>
             26   18   TRUE  0.33464409  0.33464409 <list[1]>
    ---                                                      
            222    7  FALSE  0.07149053  0.07149053 <list[1]>
            224    7  FALSE  1.65191498  1.65191498 <list[1]>
            228    6  FALSE -0.40258256 -0.40258256 <list[1]>
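
    From these, a sketch of getting the quantities asked about for calibration: the survival/event probability at a chosen time point and the linear predictors. (The $distr$survival() accessor is assumed from the distr6 API that mlr3proba predictions use.)

    p = data$prediction[[12]]
    surv_12m  = p$distr$survival(12)   # S(t = 12 months) per observation
    event_12m = 1 - surv_12m           # event probability by 12 months
    lp        = p$lp                   # linear predictors for the same rows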