xgboostmlr3proba

xgboost prediction (of cox linear predictor) is different from mlr3 xgboost.cox


I would like to reproduce the fitting (training and subsequent prediction) of an XGBoost model in both the mlr3 and xgboost packages. The following example uses the lung dataset and, for simplicity, predicts on the training data. The linear predictors from xgboost (xgb_pred) and mlr3 (mlr3_xgb$lp) are not quite the same. Any advice on why this might be the case would be greatly appreciated (hopefully it is just a glitch in my code or a gap in my understanding).

library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(mlr3proba)
library(xgboost)

## mlr3 as an example ----
task_lung = tsk('lung')
lung = task_lung$data()
xgb_basic = as_learner(
  po("encode") %>>%  lrn("surv.xgboost.cox",  eta = 0.0103))

set.seed(123)
xgb_basic$train(task_lung)
mlr3_xgb = xgb_basic$predict(task_lung)


## use xgboost package -----
# labels to be attached to the dataset
label <- ifelse(lung$status == 0, lung$time, -lung$time) # label for survival:cox
# interval bounds (not used below; only an AFT objective would need them)
y_lower_bound = lung$time
y_upper_bound = ifelse(lung$status == 0, +Inf, lung$time)

xgb_data = model.matrix(~ . + 0, data = lung[, -c(1, 2), with = FALSE]) # one-hot encoding of the features (columns 1 and 2 dropped)

# Data matrix
dmat = xgb.DMatrix(xgb_data, label=label) # for cox


params <- list(objective='survival:cox',  # Cox partial likelihood objective
               eval_metric='cox-nloglik',
               learning_rate=0.0103)  # aka eta

set.seed(123) 
bst <- xgb.train(params=params, 
                 data = dmat, 
                 nrounds=1, 
                 watchlist=list(train = dmat, eval=dmat))
#> [1]  train-cox-nloglik:3.896049  eval-cox-nloglik:3.896049

xgb_pred = predict(bst, newdata=dmat)

round(exp(mlr3_xgb$lp),3)
#>   [1] 0.500 0.510 0.496 0.510 0.500 0.503 0.500 0.498 0.498 0.499 0.498 0.510
#>  [13] 0.500 0.501 0.498 0.499 0.499 0.513 0.502 0.513 0.499 0.513 0.510 0.513
#>  [25] 0.499 0.496 0.513 0.499 0.500 0.505 0.500 0.500 0.513 0.505 0.498 0.499
#>  [37] 0.496 0.499 0.498 0.498 0.500 0.499 0.499 0.500 0.503 0.499 0.510 0.510
#>  [49] 0.495 0.499 0.505 0.499 0.499 0.513 0.504 0.499 0.498 0.499 0.505 0.498
#>  [61] 0.503 0.503 0.504 0.496 0.500 0.499 0.498 0.499 0.513 0.499 0.505 0.510
#>  [73] 0.513 0.499 0.495 0.499 0.505 0.499 0.503 0.513 0.505 0.503 0.500 0.513
#>  [85] 0.510 0.500 0.502 0.505 0.505 0.499 0.513 0.500 0.505 0.510 0.496 0.496
#>  [97] 0.499 0.503 0.505 0.496 0.499 0.503 0.513 0.505 0.513 0.499 0.498 0.499
#> [109] 0.503 0.505 0.500 0.510 0.513 0.502 0.502 0.499 0.495 0.504 0.500 0.499
#> [121] 0.498 0.495 0.503 0.499 0.499 0.499 0.501 0.499 0.500 0.503 0.499 0.513
#> [133] 0.499 0.495 0.498 0.499 0.501 0.495 0.505 0.498 0.510 0.503 0.505 0.498
#> [145] 0.499 0.500 0.499 0.495 0.502 0.499 0.513 0.495 0.495 0.495 0.510 0.495
#> [157] 0.505 0.502 0.498 0.513 0.500 0.495 0.496 0.499 0.499 0.500 0.499 0.499
round(xgb_pred,3)
#>   [1] 0.496 0.499 0.495 0.501 0.497 0.495 0.502 0.499 0.499 0.496 0.495 0.501
#>  [13] 0.496 0.496 0.496 0.498 0.496 0.495 0.499 0.495 0.496 0.495 0.501 0.495
#>  [25] 0.495 0.497 0.495 0.496 0.500 0.501 0.500 0.495 0.501 0.503 0.500 0.505
#>  [37] 0.498 0.505 0.505 0.496 0.498 0.496 0.496 0.501 0.499 0.505 0.505 0.495
#>  [49] 0.505 0.505 0.500 0.496 0.500 0.495 0.498 0.496 0.498 0.499 0.503 0.498
#>  [61] 0.495 0.506 0.510 0.506 0.496 0.505 0.505 0.503 0.495 0.496 0.505 0.503
#>  [73] 0.498 0.505 0.500 0.505 0.496 0.499 0.498 0.501 0.503 0.506 0.505 0.495
#>  [85] 0.495 0.502 0.495 0.500 0.497 0.497 0.495 0.496 0.496 0.499 0.505 0.507
#>  [97] 0.496 0.498 0.498 0.502 0.496 0.499 0.501 0.503 0.495 0.505 0.500 0.510
#> [109] 0.499 0.500 0.495 0.495 0.495 0.495 0.510 0.507 0.500 0.506 0.502 0.505
#> [121] 0.503 0.505 0.498 0.510 0.505 0.510 0.496 0.507 0.498 0.499 0.505 0.495
#> [133] 0.500 0.510 0.505 0.506 0.498 0.506 0.497 0.505 0.500 0.498 0.500 0.507
#> [145] 0.496 0.505 0.510 0.498 0.505 0.500 0.495 0.505 0.510 0.510 0.501 0.510
#> [157] 0.503 0.505 0.500 0.498 0.501 0.507 0.505 0.505 0.506 0.501 0.505 0.505

Created on 2024-08-30 by the reprex package (v2.0.1)

Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#>  setting  value                       
#>  version  R version 4.1.1 (2021-08-10)
#>  os       macOS Big Sur 10.16         
#>  system   x86_64, darwin17.0          
#>  ui       X11                         
#>  language (EN)                        
#>  collate  en_US.UTF-8                 
#>  ctype    en_US.UTF-8                 
#>  tz       Australia/Adelaide          
#>  date     2024-08-30                  
#> 
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version    date       lib
#>  backports           1.5.0      2024-05-23 [1]
#>  checkmate           2.3.1      2023-12-04 [1]
#>  cli                 3.6.3      2024-06-21 [1]
#>  codetools           0.2-18     2020-11-04 [2]
#>  colorspace          2.1-0      2023-01-23 [1]
#>  crayon              1.4.1      2021-02-08 [2]
#>  data.table          1.15.4     2024-03-30 [1]
#>  dictionar6          0.1.3      2021-09-13 [1]
#>  digest              0.6.36     2024-06-23 [1]
#>  distr6              1.8.4      2024-06-13 [1]
#>  dplyr               1.1.3      2023-09-03 [1]
#>  evaluate            0.24.0     2024-06-10 [1]
#>  fansi               1.0.6      2023-12-08 [1]
#>  fastmap             1.1.0      2021-01-25 [2]
#>  fs                  1.5.0      2020-07-31 [2]
#>  future              1.33.2     2024-03-26 [1]
#>  generics            0.1.3      2022-07-05 [1]
#>  ggplot2             3.5.1      2024-04-23 [1]
#>  globals             0.16.3     2024-03-08 [1]
#>  glue                1.7.0      2024-01-09 [1]
#>  gtable              0.3.5      2024-04-22 [1]
#>  highr               0.9        2021-04-16 [2]
#>  htmltools           0.5.6      2023-08-10 [1]
#>  jsonlite            1.7.2      2020-12-09 [2]
#>  knitr               1.33       2021-04-24 [2]
#>  lattice             0.20-44    2021-05-02 [2]
#>  lgr                 0.4.4      2022-09-05 [1]
#>  lifecycle           1.0.4      2023-11-07 [1]
#>  listenv             0.9.1      2024-01-29 [1]
#>  magrittr            2.0.3      2022-03-30 [1]
#>  Matrix              1.3-4      2021-06-01 [2]
#>  mlr3              * 0.20.0     2024-06-28 [1]
#>  mlr3extralearners * 0.8.0-9000 2024-06-15 [1]
#>  mlr3misc            0.15.1     2024-06-24 [1]
#>  mlr3pipelines     * 0.6.0      2024-07-16 [1]
#>  mlr3proba         * 0.6.3      2024-06-13 [1]
#>  mlr3viz             0.9.0      2024-07-01 [1]
#>  munsell             0.5.1      2024-04-01 [1]
#>  ooplah              0.2.0      2022-01-21 [1]
#>  palmerpenguins      0.1.1      2022-08-15 [1]
#>  paradox             1.0.1      2024-07-09 [1]
#>  parallelly          1.37.1     2024-02-29 [1]
#>  param6              0.2.4      2023-11-22 [1]
#>  pillar              1.9.0      2023-03-22 [1]
#>  pkgconfig           2.0.3      2019-09-22 [2]
#>  R6                  2.5.1      2021-08-19 [1]
#>  Rcpp                1.0.12     2024-01-09 [1]
#>  reprex              2.0.1      2021-08-05 [1]
#>  RhpcBLASctl         0.23-42    2023-02-11 [1]
#>  rlang               1.1.4      2024-06-04 [1]
#>  rmarkdown           2.10       2021-08-06 [2]
#>  rstudioapi          0.15.0     2023-07-07 [1]
#>  scales              1.3.0      2023-11-28 [1]
#>  sessioninfo         1.1.1      2018-11-05 [2]
#>  set6                0.2.6      2023-11-22 [1]
#>  stringi             1.7.3      2021-07-16 [2]
#>  stringr             1.5.0      2022-12-02 [1]
#>  survival            3.7-0      2024-06-05 [1]
#>  tibble              3.2.1      2023-03-20 [1]
#>  tidyselect          1.2.0      2022-10-10 [1]
#>  utf8                1.2.4      2023-10-22 [1]
#>  uuid                1.2-0      2024-01-14 [1]
#>  vctrs               0.6.5      2023-12-01 [1]
#>  withr               3.0.0      2024-01-16 [1]
#>  xfun                0.25       2021-08-06 [2]
#>  xgboost           * 1.7.8.1    2024-07-24 [1]
#>  yaml                2.2.1      2020-02-01 [2]
#>  source                                    
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/distr6@95d7359)             
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3extralearners@6dc6965)
#>  CRAN (R 4.1.1)                            
#>  Github (mlr-org/mlr3pipelines@c542a26)    
#>  Github (mlr-org/mlr3proba@5205752)        
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  Github (xoopR/param6@0fa3577)             
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  Github (xoopR/set6@a901255)               
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.2)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#>  CRAN (R 4.1.1)                            
#>  CRAN (R 4.1.0)                            
#> 
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library

Thank you.


Solution

  • Difficult to say for sure, though I am quite positive that we are doing the correct data transformation. Some possible things to check out:
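
One candidate along these lines is the sign convention of the Cox label. The xgboost documentation for `survival:cox` treats negative label values as right-censored, whereas the `label` line in the question makes the times negative when `status != 0`. If `status == 1` marks an event in `tsk('lung')` (an assumption here, although the question's own `y_upper_bound` line also treats `status == 0` as censored), the labels would be flipped relative to that convention. A minimal base-R sketch on made-up toy data (the vectors `time` and `status` are hypothetical, not the lung data):

```r
# Toy data (hypothetical): status 1 = event, 0 = censored (assumed coding)
time   <- c(5, 8, 12, 20)
status <- c(1, 0, 1, 0)

# Convention documented for xgboost's survival:cox:
# positive label = observed event time, negative label = right-censored
label_cox <- ifelse(status == 1, time, -time)

# Coding used in the question: negative whenever status != 0
label_question <- ifelse(status == 0, time, -time)

print(label_cox)       # 5 -8 12 -20
print(label_question)  # -5 8 -12 20

# Under the assumed status coding the two are exact sign flips of each other
flipped <- all(label_cox == -label_question)
print(flipped)
```

If the status coding in the task is as assumed, re-running the xgboost part of the question with the first labelling would be a cheap way to see whether the predictions move closer to `exp(mlr3_xgb$lp)`. Other differences could still remain.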