r survival-analysis survival mlr3 dalex

survxai explainer with an mlr3proba model


I am trying to build a survxai explainer from a survival model built with mlr3proba. I'm having trouble creating the predict_function necessary for the explainer. Has anyone ever tried to build something like this?

So far, my code is the following:

require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)

create_pipeops <- function(learner) {
  GraphLearner$new(po("encode") %>>% po("scale") %>>% po("learner", learner))
}

fit<-lrn("surv.deepsurv")
fit<-create_pipeops(fit)

data<-veteran
survival_task<-TaskSurv$new("veteran", veteran, time = "time", event = "status")
fit$train(survival_task)

predict_function<-function(model, newdata, times=NULL){
  if(!is.data.frame(newdata)){
    newdata <- data.frame(newdata)
  }
  surv_task<-TaskSurv$new("task", newdata, time = "time", 
                          event = "status")
  pred<-model$predict(surv_task)
  mat<-matrix(pred$data$distr, nrow = nrow(pred$data$distr))
  colnames(mat)<-colnames(pred$data$distr)
  return(mat)
}

explainer<-survxai::explain(model = fit, data = veteran[,-c(3,4)],
                            y = Surv(veteran$time, veteran$status),
                            predict_function = predict_function)

pred_breakdown<-prediction_breakdown(explainer, veteran[1,])

It throws the following error:

Error in [.data.table(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status

I suspect that once that one is solved there may be more. I don't fully understand the structure of the object that the function returns.

In the predict_function, I included the times argument because, according to the R help page, the function must take three arguments (the model, the new data and the time points).
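The shape of the returned object can be illustrated without mlr3 at all. This is a minimal base-R sketch, assuming (as the t() call in the solution suggests) that survxai expects a numeric matrix of survival probabilities with one row per observation in newdata and one column per time point. toy_model and toy_predict are hypothetical names; the exponential model only stands in for a fitted learner.

```r
# Hypothetical toy "model": exponential survival whose hazard rate depends on
# one covariate x. It stands in for a fitted learner only to show the shape
# of the object a survxai predict_function is assumed to return.
toy_model <- list(base_rate = 0.01, beta = 0.5)

toy_predict <- function(model, newdata, times) {
  if (!is.data.frame(newdata)) {
    newdata <- data.frame(newdata)
  }
  # One hazard rate per observation (row of newdata)
  rate <- model$base_rate * exp(model$beta * newdata$x)
  # Exponential survival S(t) = exp(-rate * t); outer() builds an
  # nrow(newdata) x length(times) matrix: rows = observations, columns = times
  surv <- exp(-outer(rate, times))
  colnames(surv) <- times
  surv
}

newdata <- data.frame(x = c(-1, 0, 1))
toy_predict(toy_model, newdata, times = c(0, 10, 50, 100))
```

Each row starts at 1 (survival at time 0) and is non-increasing across the columns; a predict_function built on a real model should return a matrix with the same orientation.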


Solution

  • Here is a working example with randomForestSRC; you can simply change surv.rfsrc to surv.deepsurv for your example. BTW, we are planning to implement this within mlr3proba soon, or I might add it directly to survivalmodels; still deciding!

    library(mlr3proba)
    #> Loading required package: mlr3
    #> Warning: package 'mlr3' was built under R version 4.1.3
    library(mlr3extralearners)
    #> 
    #> Attaching package: 'mlr3extralearners'
    #> The following objects are masked from 'package:mlr3':
    #> 
    #>     lrn, lrns
    library(survxai)
    #> Loading required package: prodlim
    #> Welcome to survxai (version: 0.2.1).
    #> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
    library(survival)
    data(pbc, package = "randomForestSRC")
    pbc <- pbc[complete.cases(pbc), ]
    task <- as_task_surv(pbc, event = "status", time = "days")
    split <- partition(task)
    predict_times <- function(model, data, times) {
      t(model$predict_newdata(data)$distr$survival(times))
    }
    model <- lrn("surv.rfsrc")$train(task, row_ids = split$train)
    surve_cph <- explain(
      model = model, data = pbc[, -c(1, 2)],
      y = Surv(pbc$days, pbc$status),
      predict_function = predict_times
    )
    prediction_breakdown(surve_cph, pbc[1, -c(1, 2)])
    #>             contribution
    #> bili            -35.079%
    #> edema           -10.278%
    #> ascites          -5.505%
    #> copper           -1.084%
    #> stage            -0.773%
    #> prothrombin      -0.421%
    #> albumin          -0.247%
    #> sgot             -0.143%
    #> hepatom          -0.098%
    #> spiders          -0.086%
    #> alk              -0.043%
    #> trig             -0.041%
    #> age              -0.035%
    

    Created on 2022-06-07 by the reprex package (v2.0.1)
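Why the question's predict_function fails can be checked directly in base R. Assuming the survival package's veteran data with its usual column order, columns 3 and 4 are time and status, so veteran[, -c(3, 4)], which the question passes as data to explain(), contains neither. A predict_function that rebuilds a TaskSurv from newdata then most likely hits the status error, whereas predict_newdata, as the working example shows, only needs the feature columns. The name explainer_data is hypothetical.

```r
library(survival)  # ships the veteran lung cancer trial data

# Columns 3 and 4 of veteran hold the survival outcome
names(veteran)[c(3, 4)]

# Hypothetical name for the data the question passes to explain()
explainer_data <- veteran[, -c(3, 4)]

# Neither outcome column survives the subsetting, so TaskSurv$new(...,
# event = "status") inside predict_function has nothing to look up
c("time", "status") %in% names(explainer_data)
```

The working example follows the same pattern with pbc[, -c(1, 2)]: the outcome columns are kept out of data and supplied only through y.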