r mlr3

mlr3: Error Accessing Model Coefficients in lrn("surv.penalized")


I encountered an error when trying to access model coefficients using mlr3. Can you help me understand how to do this properly?

library(haven)
library(tidyverse)
library(survival)
library(penalized)
#> Welcome to penalized. For extended examples, see vignette("penalized").
library(mlr3)
library(distr6)
#> 
#> Attaching package: 'distr6'
#> The following object is masked from 'package:stats':
#> 
#>     qqplot
#> The following object is masked from 'package:base':
#> 
#>     truncate
library(mlr3verse)
library(mlr3proba)
library(mlr3learners)
library(mlr3pipelines)
library(mlr3extralearners)



# Load the raw data.
# NOTE(review): absolute, machine-specific path -- adjust it to reproduce.
data <- read.csv("C:/Users/Click_32235414/Documents/liver30.csv")

# Convert the data to a tibble
tb <- as_tibble(data)

# Drop variables with a lot of missing values
tb <- subset(tb, select = -c(
  alcohol, smoking, micro_vesicular_steatosis_donor_liver_biopsy,
  macro_vesicular_steatosis_donor_liver_biopsy, hdl_pre, hdl_post, phos_pre,
  mg_pre, tg_pre, tg_post, ldh_pre, ldh_post,
  ldl_pre, ldl_post, gamma_gt_pre, gamma_gt_post, t_chol_pre, t_chol_post,
  micro_vesicular_steatosis_number_donor_liver_biopsy,
  reperfusion_syndrome, fibrosis_donor_liver_biopsy,
  macro_vesicular_steatosis_number_donor_liver_biopsy
))

tb <- subset(tb, select = -c(bun_pre, bun_post, pt_pre, pt_post))

# Keep only rows with a recorded time to death (the survival time)
tb <- tb[!is.na(tb$time_to_death), ]

# No attach(tb): every column is accessed through `tb` or by name inside the
# task, and attaching would leave a stale copy of the data on the search path.

## Task Definition
tsk_s <- as_task_surv(
  tb,
  time = "time_to_death",
  event = "status",
  type = "right"
)
# Columns whose missing values are imputed based on a histogram
hist_cols <- c(
  "meld_peld_score", "alk_pre", "alk_post", "alb_pre", "alb_post",
  "creatinine_pre", "creatinine_post", "na_pre", "na_post",
  "direct_billi_pre", "direct_billi_post",
  "total_billi_pre", "total_billi_post",
  "inr_pre", "inr_post", "phos_post", "mg_post", "ppt_pre", "ptt_post", "bmi",
  "DeRitis_ratio_pre", "DeRitis_ratio_post", "cold_time", "warm_time",
  "graft_weight", "max_Tacrolimus"
)
imputer_hist <- po("imputehist", affect_columns = selector_name(hist_cols))

# Columns whose missing values are imputed using the mode
mode_cols <- c(
  "biliary_anastomosis", "vasopressors", "acute_kidney_disease_pm_history",
  "diabetes_mellitus_pm_history", "hypertension_pm_history",
  "infections_pm_history", "sbp_pm_history", "cardiovascular_pm_history",
  "previous_hospitalization_pm_history", "Sirolimus"
)
imputer_mode <- po("imputemode", affect_columns = selector_name(mode_cols))

# Train the histogram imputer on its own and print the number of missing
# values per column in the task it returns.
imputer_hist$train(list(tsk_s))[[1]]$missings()
#>                       time_to_death                              status 
#>                                   0                                   0 
#>                           Sirolimus     acute_kidney_disease_pm_history 
#>                                  19                                 107 
#>                                 age                 biliary_anastomosis 
#>                                   0                                 101 
#>           cardiovascular_pm_history        diabetes_mellitus_pm_history 
#>                                 107                                 107 
#>                              gender             hypertension_pm_history 
#>                                   0                                 107 
#>               infections_pm_history previous_hospitalization_pm_history 
#>                                 107                                 475 
#>                      sbp_pm_history                        vasopressors 
#>                                 107                                  37 
#>                  DeRitis_ratio_post                   DeRitis_ratio_pre 
#>                                   0                                   0 
#>                            alb_post                             alb_pre 
#>                                   0                                   0 
#>                            alk_post                             alk_pre 
#>                                   0                                   0 
#>                                 bmi                           cold_time 
#>                                   0                                   0 
#>                     creatinine_post                      creatinine_pre 
#>                                   0                                   0 
#>                   direct_billi_post                    direct_billi_pre 
#>                                   0                                   0 
#>                        graft_weight                            inr_post 
#>                                   0                                   0 
#>                             inr_pre                      max_Tacrolimus 
#>                                   0                                   0 
#>                     meld_peld_score                             mg_post 
#>                                   0                                   0 
#>                             na_post                              na_pre 
#>                                   0                                   0 
#>                           phos_post                             ppt_pre 
#>                                   0                                   0 
#>                            ptt_post                    total_billi_post 
#>                                   0                                   0 
#>                     total_billi_pre                           warm_time 
#>                                   0                                   0
# Train the mode imputer on its own (on the original task) and print the
# number of missing values per column in the task it returns.
imputer_mode$train(list(tsk_s))[[1]]$missings()
#>                       time_to_death                              status 
#>                                   0                                   0 
#>                  DeRitis_ratio_post                   DeRitis_ratio_pre 
#>                                 156                                 234 
#>                                 age                            alb_post 
#>                                   0                                  52 
#>                             alb_pre                            alk_post 
#>                                 258                                  64 
#>                             alk_pre                                 bmi 
#>                                 155                                 137 
#>                           cold_time                     creatinine_post 
#>                                 129                                  32 
#>                      creatinine_pre                   direct_billi_post 
#>                                 201                                  67 
#>                    direct_billi_pre                              gender 
#>                                 185                                   0 
#>                        graft_weight                            inr_post 
#>                                 118                                  79 
#>                             inr_pre                      max_Tacrolimus 
#>                                 210                                 481 
#>                     meld_peld_score                             mg_post 
#>                                 135                                 136 
#>                             na_post                              na_pre 
#>                                  33                                 206 
#>                           phos_post                             ppt_pre 
#>                                  91                                 463 
#>                            ptt_post                    total_billi_post 
#>                                 116                                  66 
#>                     total_billi_pre                           warm_time 
#>                                 163                                 128 
#>                           Sirolimus     acute_kidney_disease_pm_history 
#>                                   0                                   0 
#>                 biliary_anastomosis           cardiovascular_pm_history 
#>                                   0                                   0 
#>        diabetes_mellitus_pm_history             hypertension_pm_history 
#>                                   0                                   0 
#>               infections_pm_history previous_hospitalization_pm_history 
#>                                   0                                   0 
#>                      sbp_pm_history                        vasopressors 
#>                                   0                                   0
# Chain both imputers into one graph and run it on the task to get the
# imputed survival task.
impute_graph <- imputer_hist %>>% imputer_mode
surv_task <- impute_graph$train(tsk_s)[[1]]

# Seeded split of the imputed task (ratio 0.8)
set.seed(42)
part <- partition(surv_task, ratio = 0.8)

# Penalized survival learner with a fixed lambda1 penalty
learner_penalized <- lrn("surv.penalized", lambda1 = 25.8)

# Wrap the base learner in the crank compositor pipeline and turn the
# resulting graph into a learner.
learner <- as_learner(
  ppl(
    "crankcompositor",
    learner = learner_penalized,
    response = TRUE,
    method = "median",
    overwrite = FALSE
  )
)

# Fit on the training rows only
learner$train(surv_task, part$train)
#> # nonzero coefficients: 38# nonzero coefficients: 25          # nonzero coefficients: 27          # nonzero coefficients: 28          # nonzero coefficients: 27          # nonzero coefficients: 27          # nonzero coefficients: 26          # nonzero coefficients: 25          # nonzero coefficients: 25          # nonzero coefficients: 25          # nonzero coefficients: 25          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 25          # nonzero coefficients: 25          # nonzero coefficients: 24          # nonzero coefficients: 25          # nonzero coefficients: 25          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          # nonzero coefficients: 24          
# Attempt to extract the fitted coefficients. `learner::coef` is read by R as
# "the function coef from a package named learner", hence the error below.
learner::coef(learner$model)
#> Error in loadNamespace(x): there is no package called 'learner'
# Attempt to list the features kept by the fitted model; this call errors too.
learner$selected_features()
#> Error in eval(expr, envir, enclos): attempt to apply non-function

Created on 2024-09-10 with reprex v2.1.0


Solution

  • This is similar to the question here. You have to use:

    penalized::coef(learner$model)
    

    instead of

    learner::coef(learner$model)
    

    The `$selected_features()` function is a work in progress (WIP); see the PR.

    Also, we now have a better pipeline to estimate the response (survival time); check out the response composition pipeline.