Tags: r, machine-learning, survival-analysis, tidymodels, yardstick

Calibration Plots for Survival Analysis


I am unable to create calibration plots for my survival analysis project (an oesophageal cancer dataset).

I have finalised my model (AORSF) after tuning:

aorsf_fit <- last_fit(
  final_aorsf_wf,
  split = initial_split(final_main, prop = 0.75),
  metrics = survival_metrics,
  eval_time = time_points_complete
)

I then collect my prediction set, using the following code:

> predictions <- collect_predictions(aorsf_fit)
> predictions  
# A tibble: 687 × 6
   .pred            .pred_time id                .row   surv .config             
   <list>                <dbl> <chr>            <int> <Surv> <chr>               
 1 <tibble [8 × 3]>      107.  train/test split     3  57.8+ Preprocessor1_Model1
 2 <tibble [8 × 3]>       96.1 train/test split     7  96.8+ Preprocessor1_Model1
 3 <tibble [8 × 3]>       94.2 train/test split    11 130.0  Preprocessor1_Model1
 4 <tibble [8 × 3]>       11.4 train/test split    15   9.0  Preprocessor1_Model1
 5 <tibble [8 × 3]>      102.  train/test split    16  69.1+ Preprocessor1_Model1
 6 <tibble [8 × 3]>       37.8 train/test split    17  33.0  Preprocessor1_Model1
 7 <tibble [8 × 3]>      103.  train/test split    18 142.8  Preprocessor1_Model1
 8 <tibble [8 × 3]>       23.4 train/test split    20  13.5  Preprocessor1_Model1
 9 <tibble [8 × 3]>       89.7 train/test split    21 146.5  Preprocessor1_Model1
10 <tibble [8 × 3]>      107.  train/test split    23  60.1+ Preprocessor1_Model1
# ℹ 677 more rows

I am even able to get time-dependent ROC curves (thanks to the roc_curve_survival() function):

predictions |> 
  roc_curve_survival(truth = surv, .pred) |> 
  filter(.eval_time == 60) |> 
  ggplot(aes(1 - specificity, sensitivity)) +
  geom_line() +
  theme_minimal()

However, I am unable to use the '.pred' list column to create a calibration plot. I have tried the code suggested in the linked article: An introduction to calibration with tidymodels

I want to build the calibration plot from the predicted survival data in the 'predictions' object. Each nested tibble in the '.pred' column contains the following 3 columns: .eval_time, .pred_survival and .weight_censored
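As a minimal sketch of what that extraction could look like, a list column of tibbles can usually be flattened with tidyr::unnest(). The toy_predictions object below is a made-up stand-in for the real 'predictions' object (its values are invented for illustration), so this only shows the shape of the operation:

```r
library(tibble)
library(tidyr)
library(dplyr)

# Hypothetical stand-in for collect_predictions() output: one row per
# patient, with a nested tibble of per-time predictions in `.pred`.
# All numbers are invented for illustration.
toy_predictions <- tibble(
  .row = c(3L, 7L),
  .pred = list(
    tibble(
      .eval_time = c(12, 60),
      .pred_survival = c(0.90, 0.55),
      .weight_censored = c(1.0, 1.2)
    ),
    tibble(
      .eval_time = c(12, 60),
      .pred_survival = c(0.80, 0.30),
      .weight_censored = c(1.0, NA)
    )
  )
)

# unnest() expands each nested tibble into ordinary rows and columns,
# giving one row per patient per evaluation time
flat <- toy_predictions |>
  unnest(cols = .pred)

# Keep a single evaluation time (60 here, as in the ROC example)
flat_60 <- flat |>
  filter(.eval_time == 60)
```

After this step .pred_survival is an ordinary numeric column, so it can be passed to ggplot() or summarised like any other variable.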

I am sure that these values, once extracted, can be used to create calibration plots like those in the linked article, showing how well the predicted survival agrees with the observed survival. (My dataset includes the observed survival as well.)

So far I have only tried the following, with no result:

predictions |> 
+ ggplot(aes(.pred_survival))

Solution

  • We don't yet have an interface for calibration plots with survival models (as we do for regression and classification models).

    I wrote some unsupported code for an article on tidymodels.org to create this figure. Here it is; look at the code chunk labeled prep-cal-for-shiny.

    It is not productionized or tested, but you might be able to use it for your analysis.
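For readers who want a starting point without that chunk, here is a hedged sketch of one common alternative: a binned calibration check at a single evaluation time. It is not the code from the tidymodels.org article. It groups patients into quantile bins of predicted survival and compares the mean prediction in each bin with the Kaplan-Meier estimate at the same time. The data are simulated, and the names sim, cal and cal_plot, the choice of 5 bins, and the evaluation time of 60 are all assumptions made for illustration.

```r
library(survival)
library(dplyr)
library(tibble)
library(ggplot2)

set.seed(1)
n <- 400
eval_time <- 60

# Simulated stand-in for unnested predictions at one evaluation time.
# With real output, `time` and `status` would presumably come from the
# `surv` column and `.pred_survival` from unnesting `.pred`.
rate <- runif(n, 0.2, 2) / 100
event_time <- rexp(n, rate = rate)
cens_time <- runif(n, 20, 200)

sim <- tibble(
  time = pmin(event_time, cens_time),
  status = as.integer(event_time <= cens_time),
  # True exponential survival probability at eval_time, used as the "prediction"
  .pred_survival = exp(-rate * eval_time)
)

# For each quantile bin of predicted survival, compare the mean prediction
# with the Kaplan-Meier estimate of survival at eval_time
cal <- sim |>
  mutate(bin = ntile(.pred_survival, 5)) |>
  group_by(bin) |>
  group_modify(function(d, key) {
    km <- survfit(Surv(time, status) ~ 1, data = d)
    s <- summary(km, times = eval_time, extend = TRUE)
    tibble(
      predicted = mean(d$.pred_survival),
      observed = s$surv,
      lower = s$lower,
      upper = s$upper
    )
  }) |>
  ungroup()

# Points near the dashed diagonal indicate good calibration at eval_time
cal_plot <- ggplot(cal, aes(predicted, observed)) +
  geom_abline(linetype = "dashed") +
  geom_errorbar(aes(ymin = lower, ymax = upper), width = 0.02) +
  geom_point() +
  coord_fixed(xlim = c(0, 1), ylim = c(0, 1)) +
  labs(
    x = "Mean predicted survival probability",
    y = "Observed survival (Kaplan-Meier)"
  ) +
  theme_minimal()
```

Printing cal_plot draws the figure. The Kaplan-Meier step is what accounts for censoring here; simply averaging the observed event indicator within each bin would ignore patients censored before the evaluation time.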