I would like to reproduce the fitting (training and subsequent prediction) of an XGBoost model in both the mlr3 and xgboost packages. See the following example, which uses the lung dataset and, for simplicity, predicts on the training data. The predictions from xgboost (`xgb_pred`, which `predict()` returns on the hazard-ratio scale for `survival:cox`) and from mlr3 (`exp(mlr3_xgb$lp)`) are not quite the same. Any advice on why this might be the case would be greatly appreciated (hopefully it is just a glitch in my coding or a lack of understanding).
library(mlr3)
library(mlr3pipelines)
library(mlr3extralearners)
library(mlr3proba)
library(xgboost)
## mlr3 as an example ----
task_lung = tsk('lung')
lung = task_lung$data()
xgb_basic = as_learner(
po("encode") %>>% lrn("surv.xgboost.cox", eta = 0.0103))
set.seed(123)
xgb_basic$train(task_lung)
mlr3_xgb = xgb_basic$predict(task_lung)
## use xgboost package -----
# labels to be attached to dataset
label <- ifelse(lung$status == 0, lung$time, -lung$time) # label
y_lower_bound = lung$time # interval bounds are only used by survival:aft (unused below)
y_upper_bound = ifelse(lung$status == 0, +Inf, lung$time)
xgb_data = model.matrix(~ . + 0, data = lung[, -c(1, 2), with = FALSE]) # one-hot encode the features (columns 1-2 are time and status)
# Data matrix
dmat = xgb.DMatrix(xgb_data, label=label) # for cox
params <- list(objective = 'survival:cox', # training parameters
eval_metric='cox-nloglik',
learning_rate=0.0103) #aka eta
set.seed(123)
bst <- xgb.train(params=params,
data = dmat,
nrounds=1,
watchlist=list(train = dmat, eval=dmat))
#> [1] train-cox-nloglik:3.896049 eval-cox-nloglik:3.896049
xgb_pred = predict(bst, newdata=dmat)
round(exp(mlr3_xgb$lp),3)
#> [1] 0.500 0.510 0.496 0.510 0.500 0.503 0.500 0.498 0.498 0.499 0.498 0.510
#> [13] 0.500 0.501 0.498 0.499 0.499 0.513 0.502 0.513 0.499 0.513 0.510 0.513
#> [25] 0.499 0.496 0.513 0.499 0.500 0.505 0.500 0.500 0.513 0.505 0.498 0.499
#> [37] 0.496 0.499 0.498 0.498 0.500 0.499 0.499 0.500 0.503 0.499 0.510 0.510
#> [49] 0.495 0.499 0.505 0.499 0.499 0.513 0.504 0.499 0.498 0.499 0.505 0.498
#> [61] 0.503 0.503 0.504 0.496 0.500 0.499 0.498 0.499 0.513 0.499 0.505 0.510
#> [73] 0.513 0.499 0.495 0.499 0.505 0.499 0.503 0.513 0.505 0.503 0.500 0.513
#> [85] 0.510 0.500 0.502 0.505 0.505 0.499 0.513 0.500 0.505 0.510 0.496 0.496
#> [97] 0.499 0.503 0.505 0.496 0.499 0.503 0.513 0.505 0.513 0.499 0.498 0.499
#> [109] 0.503 0.505 0.500 0.510 0.513 0.502 0.502 0.499 0.495 0.504 0.500 0.499
#> [121] 0.498 0.495 0.503 0.499 0.499 0.499 0.501 0.499 0.500 0.503 0.499 0.513
#> [133] 0.499 0.495 0.498 0.499 0.501 0.495 0.505 0.498 0.510 0.503 0.505 0.498
#> [145] 0.499 0.500 0.499 0.495 0.502 0.499 0.513 0.495 0.495 0.495 0.510 0.495
#> [157] 0.505 0.502 0.498 0.513 0.500 0.495 0.496 0.499 0.499 0.500 0.499 0.499
round(xgb_pred,3)
#> [1] 0.496 0.499 0.495 0.501 0.497 0.495 0.502 0.499 0.499 0.496 0.495 0.501
#> [13] 0.496 0.496 0.496 0.498 0.496 0.495 0.499 0.495 0.496 0.495 0.501 0.495
#> [25] 0.495 0.497 0.495 0.496 0.500 0.501 0.500 0.495 0.501 0.503 0.500 0.505
#> [37] 0.498 0.505 0.505 0.496 0.498 0.496 0.496 0.501 0.499 0.505 0.505 0.495
#> [49] 0.505 0.505 0.500 0.496 0.500 0.495 0.498 0.496 0.498 0.499 0.503 0.498
#> [61] 0.495 0.506 0.510 0.506 0.496 0.505 0.505 0.503 0.495 0.496 0.505 0.503
#> [73] 0.498 0.505 0.500 0.505 0.496 0.499 0.498 0.501 0.503 0.506 0.505 0.495
#> [85] 0.495 0.502 0.495 0.500 0.497 0.497 0.495 0.496 0.496 0.499 0.505 0.507
#> [97] 0.496 0.498 0.498 0.502 0.496 0.499 0.501 0.503 0.495 0.505 0.500 0.510
#> [109] 0.499 0.500 0.495 0.495 0.495 0.495 0.510 0.507 0.500 0.506 0.502 0.505
#> [121] 0.503 0.505 0.498 0.510 0.505 0.510 0.496 0.507 0.498 0.499 0.505 0.495
#> [133] 0.500 0.510 0.505 0.506 0.498 0.506 0.497 0.505 0.500 0.498 0.500 0.507
#> [145] 0.496 0.505 0.510 0.498 0.505 0.500 0.495 0.505 0.510 0.510 0.501 0.510
#> [157] 0.503 0.505 0.500 0.498 0.501 0.507 0.505 0.505 0.506 0.501 0.505 0.505
Created on 2024-08-30 by the reprex package (v2.0.1)
Session info
sessioninfo::session_info()
#> ─ Session info ───────────────────────────────────────────────────────────────
#> setting value
#> version R version 4.1.1 (2021-08-10)
#> os macOS Big Sur 10.16
#> system x86_64, darwin17.0
#> ui X11
#> language (EN)
#> collate en_US.UTF-8
#> ctype en_US.UTF-8
#> tz Australia/Adelaide
#> date 2024-08-30
#>
#> ─ Packages ───────────────────────────────────────────────────────────────────
#>  package           * version    date       lib source
#>  backports           1.5.0      2024-05-23 [1] CRAN (R 4.1.1)
#>  checkmate           2.3.1      2023-12-04 [1] CRAN (R 4.1.1)
#>  cli                 3.6.3      2024-06-21 [1] CRAN (R 4.1.1)
#>  codetools           0.2-18     2020-11-04 [2] CRAN (R 4.1.1)
#>  colorspace          2.1-0      2023-01-23 [1] CRAN (R 4.1.2)
#>  crayon              1.4.1      2021-02-08 [2] CRAN (R 4.1.0)
#>  data.table          1.15.4     2024-03-30 [1] CRAN (R 4.1.1)
#>  dictionar6          0.1.3      2021-09-13 [1] CRAN (R 4.1.0)
#>  digest              0.6.36     2024-06-23 [1] CRAN (R 4.1.1)
#>  distr6              1.8.4      2024-06-13 [1] Github (xoopR/distr6@95d7359)
#>  dplyr               1.1.3      2023-09-03 [1] CRAN (R 4.1.1)
#>  evaluate            0.24.0     2024-06-10 [1] CRAN (R 4.1.1)
#>  fansi               1.0.6      2023-12-08 [1] CRAN (R 4.1.1)
#>  fastmap             1.1.0      2021-01-25 [2] CRAN (R 4.1.0)
#>  fs                  1.5.0      2020-07-31 [2] CRAN (R 4.1.0)
#>  future              1.33.2     2024-03-26 [1] CRAN (R 4.1.1)
#>  generics            0.1.3      2022-07-05 [1] CRAN (R 4.1.2)
#>  ggplot2             3.5.1      2024-04-23 [1] CRAN (R 4.1.1)
#>  globals             0.16.3     2024-03-08 [1] CRAN (R 4.1.1)
#>  glue                1.7.0      2024-01-09 [1] CRAN (R 4.1.1)
#>  gtable              0.3.5      2024-04-22 [1] CRAN (R 4.1.1)
#>  highr               0.9        2021-04-16 [2] CRAN (R 4.1.0)
#>  htmltools           0.5.6      2023-08-10 [1] CRAN (R 4.1.1)
#>  jsonlite            1.7.2      2020-12-09 [2] CRAN (R 4.1.0)
#>  knitr               1.33       2021-04-24 [2] CRAN (R 4.1.0)
#>  lattice             0.20-44    2021-05-02 [2] CRAN (R 4.1.1)
#>  lgr                 0.4.4      2022-09-05 [1] CRAN (R 4.1.2)
#>  lifecycle           1.0.4      2023-11-07 [1] CRAN (R 4.1.1)
#>  listenv             0.9.1      2024-01-29 [1] CRAN (R 4.1.1)
#>  magrittr            2.0.3      2022-03-30 [1] CRAN (R 4.1.2)
#>  Matrix              1.3-4      2021-06-01 [2] CRAN (R 4.1.1)
#>  mlr3              * 0.20.0     2024-06-28 [1] CRAN (R 4.1.1)
#>  mlr3extralearners * 0.8.0-9000 2024-06-15 [1] Github (mlr-org/mlr3extralearners@6dc6965)
#>  mlr3misc            0.15.1     2024-06-24 [1] CRAN (R 4.1.1)
#>  mlr3pipelines     * 0.6.0      2024-07-16 [1] Github (mlr-org/mlr3pipelines@c542a26)
#>  mlr3proba         * 0.6.3      2024-06-13 [1] Github (mlr-org/mlr3proba@5205752)
#>  mlr3viz             0.9.0      2024-07-01 [1] CRAN (R 4.1.1)
#>  munsell             0.5.1      2024-04-01 [1] CRAN (R 4.1.1)
#>  ooplah              0.2.0      2022-01-21 [1] CRAN (R 4.1.2)
#>  palmerpenguins      0.1.1      2022-08-15 [1] CRAN (R 4.1.2)
#>  paradox             1.0.1      2024-07-09 [1] CRAN (R 4.1.1)
#>  parallelly          1.37.1     2024-02-29 [1] CRAN (R 4.1.1)
#>  param6              0.2.4      2023-11-22 [1] Github (xoopR/param6@0fa3577)
#>  pillar              1.9.0      2023-03-22 [1] CRAN (R 4.1.2)
#>  pkgconfig           2.0.3      2019-09-22 [2] CRAN (R 4.1.0)
#>  R6                  2.5.1      2021-08-19 [1] CRAN (R 4.1.0)
#>  Rcpp                1.0.12     2024-01-09 [1] CRAN (R 4.1.1)
#>  reprex              2.0.1      2021-08-05 [1] CRAN (R 4.1.0)
#>  RhpcBLASctl         0.23-42    2023-02-11 [1] CRAN (R 4.1.2)
#>  rlang               1.1.4      2024-06-04 [1] CRAN (R 4.1.1)
#>  rmarkdown           2.10       2021-08-06 [2] CRAN (R 4.1.0)
#>  rstudioapi          0.15.0     2023-07-07 [1] CRAN (R 4.1.1)
#>  scales              1.3.0      2023-11-28 [1] CRAN (R 4.1.1)
#>  sessioninfo         1.1.1      2018-11-05 [2] CRAN (R 4.1.0)
#>  set6                0.2.6      2023-11-22 [1] Github (xoopR/set6@a901255)
#>  stringi             1.7.3      2021-07-16 [2] CRAN (R 4.1.0)
#>  stringr             1.5.0      2022-12-02 [1] CRAN (R 4.1.2)
#>  survival            3.7-0      2024-06-05 [1] CRAN (R 4.1.1)
#>  tibble              3.2.1      2023-03-20 [1] CRAN (R 4.1.2)
#>  tidyselect          1.2.0      2022-10-10 [1] CRAN (R 4.1.2)
#>  utf8                1.2.4      2023-10-22 [1] CRAN (R 4.1.1)
#>  uuid                1.2-0      2024-01-14 [1] CRAN (R 4.1.1)
#>  vctrs               0.6.5      2023-12-01 [1] CRAN (R 4.1.1)
#>  withr               3.0.0      2024-01-16 [1] CRAN (R 4.1.1)
#>  xfun                0.25       2021-08-06 [2] CRAN (R 4.1.0)
#>  xgboost           * 1.7.8.1    2024-07-24 [1] CRAN (R 4.1.1)
#>  yaml                2.2.1      2020-02-01 [2] CRAN (R 4.1.0)
#>
#> [1] /Users/Lee/Library/R/x86_64/4.1/library
#> [2] /Library/Frameworks/R.framework/Versions/4.1/Resources/library
Thank you.
Difficult to say for sure, though I am fairly confident that we are doing the correct data transformation. Some possible things to check:
- How is the `xgboost::xgb.DMatrix` built? For the `cox` objective, we just need to negate the label (observed times) for the censored observations: see here for our internal helper functions for xgboost, and a similar question that has some code that does exactly that conversion. I think that part is the same.
- Is `nrounds = 1` in both versions? (The default recently changed to 1000, see NEWS.)
- Try `tsk("gbcs")` to simplify things when investigating issues like this.
- You use a `watchlist` in the manual version, whereas in `mlr3` it is empty (`NULL`). In the newest version we support internal validation and early stopping of xgboost, by the way, but here I would make sure every parameter is the same and that the raw xgboost models are the same (the predictions would then surely follow).
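The label conversion mentioned above (negating the observed times of censored observations) can be sketched in base R as follows. It assumes, as the question's own `y_upper_bound` line does, that `status` is coded 0 = censored and 1 = event; xgboost's documentation describes `survival:cox` labels as negative for right-censored times. Note that the question's `ifelse(lung$status == 0, lung$time, -lung$time)` appears to do the opposite (it negates the event times and leaves censored times positive), so that line may be worth double-checking. `cox_label()` is a hypothetical helper and the times below are made up.

```r
# Hypothetical helper: build the label for xgboost's "survival:cox" objective.
# Event times stay positive; right-censored times are negated.
# Assumes status is coded 0 = censored, 1 = event.
cox_label <- function(time, status) {
  # times must be positive so that the sign alone encodes censoring
  stopifnot(length(time) == length(status), all(time > 0))
  ifelse(status == 1, time, -time)
}

# Made-up example: only the third observation is censored
time   <- c(306, 455, 1010, 210)
status <- c(1, 1, 0, 1)
cox_label(time, status)
#> [1]   306   455 -1010   210
```

With the lung data this would be `label <- cox_label(lung$time, lung$status)`, passed to `xgb.DMatrix(xgb_data, label = label)` as in the question.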
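To make sure every parameter is the same, as suggested above, one option is to diff the two parameter lists before training. The base-R sketch below does that; `diff_params()` is a hypothetical helper, and the two example lists are illustrative values, not read from either package. It first renames `learning_rate` to `eta`, since xgboost accepts `learning_rate` as an alias of `eta` (the question uses one name in the manual call and the other in the mlr3 learner). In practice you could compare the manual `params` list (plus `nrounds`) against the learner's `param_set$values`; inside a GraphLearner those names carry the PipeOp id as a prefix (e.g. `surv.xgboost.cox.eta`), which you would strip first.

```r
# Hypothetical helper: names of parameters whose values differ between two
# named lists, after renaming known aliases to a canonical name.
diff_params <- function(a, b, aliases = c(learning_rate = "eta")) {
  canonical <- function(p) {
    nm <- names(p)
    is_alias <- nm %in% names(aliases)
    nm[is_alias] <- aliases[nm[is_alias]]
    setNames(p, nm)
  }
  a <- canonical(a)
  b <- canonical(b)
  keys <- union(names(a), names(b))
  # a parameter missing from one list is NULL there, so it is reported too
  differs <- vapply(keys, function(k) !isTRUE(all.equal(a[[k]], b[[k]])),
                    logical(1))
  keys[differs]
}

# Illustrative values only
manual  <- list(objective = "survival:cox", learning_rate = 0.0103, nrounds = 1)
learner <- list(objective = "survival:cox", eta = 0.0103, nrounds = 1000)
diff_params(manual, learner)
#> [1] "nrounds"
```

Using `all.equal()` rather than `identical()` means that `1` and `1L` count as the same value, which is usually what you want when comparing hyperparameters.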