I am trying to build a survxai explainer from a survival model built with mlr3proba. I'm having trouble creating the predict_function necessary for the explainer. Has anyone ever tried to build something like this?
So far, my code is the following:
require(survxai)
require(survival)
require(survivalmodels)
require(mlr3proba)
require(mlr3pipelines)
# Wrap a learner in a preprocessing pipeline and return it as a GraphLearner.
#
# The graph chains three pipe operators, in this order:
# po("encode"), po("scale"), po("learner", learner).
#
# learner: the learner object handed to po("learner", ...).
create_pipeops <- function(learner) {
  preprocessing <- po("encode") %>>% po("scale")
  graph <- preprocessing %>>% po("learner", learner)
  GraphLearner$new(graph)
}
# Build the DeepSurv learner, wrap it in the preprocessing pipeline from
# create_pipeops(), and train it on a survival task built from `veteran`.
fit <- create_pipeops(lrn("surv.deepsurv"))
data <- veteran

# The task uses the `time` column as follow-up time and `status` as the event.
survival_task <- TaskSurv$new(
  "veteran", veteran,
  time = "time", event = "status"
)
fit$train(survival_task)
# Prediction wrapper intended for survxai::explain().
#
# model:   object exposing a `$predict(task)` method.
# newdata: data to predict on; coerced to a data.frame when it is not one.
# times:   accepted for the three-argument signature but not used in the body.
#
# Returns a plain matrix holding the values of `pred$data$distr`, with the
# same column names.
#
# NOTE(review): the task is built with time = "time" and event = "status", so
# `newdata` apparently has to contain both columns -- confirm against the data
# the explainer passes in.
predict_function <- function(model, newdata, times = NULL) {
  features <- if (is.data.frame(newdata)) newdata else data.frame(newdata)
  task <- TaskSurv$new("task", features, time = "time", event = "status")
  distr <- model$predict(task)$data$distr
  # Rebuild as a bare matrix, then restore only the column names.
  surv_mat <- matrix(distr, nrow = nrow(distr))
  colnames(surv_mat) <- colnames(distr)
  surv_mat
}
# Build the survxai explainer around the trained pipeline.
# The trained learner is `fit`; this script defines no top-level object named
# `learner`, and predict_function() calls `model$predict(...)`, so the learner
# object itself is passed as `model`.
explainer <- survxai::explain(
  model = fit,
  # Drops columns 3 and 4 -- presumably `time` and `status`; confirm.
  data = veteran[, -c(3, 4)],
  y = Surv(veteran$time, veteran$status),
  predict_function = predict_function
)

# Break down the prediction for the first row of `veteran`.
pred_breakdown <- prediction_breakdown(explainer, veteran[1, ])
It throws the following error:
`Error in [.data.table(r6_private(backend)$.data, , event, with = FALSE) : column(s) not found: status`
I suspect that once this one is solved there may be more. I don't fully understand the structure of the object that the function returns.
In `predict_function`, I included the `times` argument because, according to the R help page, the function must take three arguments.
Here is a working example with randomForestSRC; you can just change `surv.rfsrc` to `surv.deepsurv` for your example. BTW, we are planning on implementing this within mlr3proba soon, or I might just add it directly to survivalmodels; still deciding!
library(mlr3proba)
#> Loading required package: mlr3
#> Warning: package 'mlr3' was built under R version 4.1.3
library(mlr3extralearners)
#>
#> Attaching package: 'mlr3extralearners'
#> The following objects are masked from 'package:mlr3':
#>
#> lrn, lrns
library(survxai)
#> Loading required package: prodlim
#> Welcome to survxai (version: 0.2.1).
#> Information about the package can be found in the GitHub repository: https://github.com/MI2DataLab/survxai
library(survival)
# Load `pbc` from randomForestSRC and keep only rows without missing values.
data(pbc, package = "randomForestSRC")
pbc <- pbc[complete.cases(pbc), , drop = FALSE]

# Survival task with `days` as the time column and `status` as the event;
# partition() supplies the split whose `$train` ids are used for fitting.
task <- as_task_surv(pbc, time = "days", event = "status")
split <- partition(task)
# Prediction wrapper passed to survxai::explain().
#
# model: object exposing `$predict_newdata()`.
# data:  new observations to predict on.
# times: time points forwarded to the predicted distribution's `$survival()`.
#
# Returns the transpose of `distr$survival(times)`.
# NOTE(review): presumably `$survival()` yields one row per time point, so the
# transpose gives one row per observation -- confirm.
predict_times <- function(model, data, times) {
  prediction <- model$predict_newdata(data)
  survival_probs <- prediction$distr$survival(times)
  t(survival_probs)
}
# Fit the surv.rfsrc learner on the training rows of the split only.
model <- lrn("surv.rfsrc")
model$train(task, row_ids = split$train)

# Explainer over all columns except the first two, with the survival outcome
# built from `days` and `status`.
surve_cph <- explain(
  model = model,
  data = pbc[, -(1:2)],
  y = Surv(pbc$days, pbc$status),
  predict_function = predict_times
)

# Break down the prediction for the first observation.
prediction_breakdown(surve_cph, pbc[1L, -(1:2)])
#> contribution
#> bili -35.079%
#> edema -10.278%
#> ascites -5.505%
#> copper -1.084%
#> stage -0.773%
#> prothrombin -0.421%
#> albumin -0.247%
#> sgot -0.143%
#> hepatom -0.098%
#> spiders -0.086%
#> alk -0.043%
#> trig -0.041%
#> age -0.035%
Created on 2022-06-07 by the reprex package (v2.0.1)