r ggplot2 dalex

Averaging results from 'model_profile' for multiple iterations of statistical models


Does anyone have any experience obtaining averaged results from multiple model_profile outputs from DALEX?

EDIT

To clarify, the model_profile function outputs partial dependence plots for one model. However, I have a list of 500 models that have been generated using different testing/training splits, and downsampled data.

Averaging results for model_profile with variable_type = 'categorical' was relatively straightforward, as each category was extracted as one value from $agr_profiles, and these values could be merged into one large data frame.

However, I am not sure how to average the full profiles across all 500 models to obtain averaged values for model_profile with variable_type = 'numerical' (continuous variables).

Reproducible example

library(ranger)
library(DALEX)

set.seed(42)  # make the random training splits reproducible

trainIdx_1 <- sample(nrow(titanic_imputed), 2/3 * nrow(titanic_imputed)) 
trainData_1 <- titanic_imputed[trainIdx_1, ]

trainIdx_2 <- sample(nrow(titanic_imputed), 2/3 * nrow(titanic_imputed)) 
trainData_2 <- titanic_imputed[trainIdx_2, ]

titanic_ranger_model_1 <- ranger(survived ~ ., data = trainData_1, num.trees = 50,
                                 probability = TRUE)
titanic_ranger_model_2 <- ranger(survived ~ ., data = trainData_2, num.trees = 50,
                                 probability = TRUE)

exp_1 <- explain(titanic_ranger_model_1, data = trainData_1) 
exp_2 <- explain(titanic_ranger_model_2, data = trainData_2)

model_profile_1 <- model_profile(exp_1, type = "partial") 
model_profile_2 <- model_profile(exp_2, type = "partial")

How could I obtain the average profile from model_profile_1 and model_profile_2?


Solution

  • I was able to get a decent-looking mean line by combining the data from both profiles into one dataset and plotting a line with ggplot2's stat_summary_bin. It's not perfect, and the number of bins may need some tuning for best results, but it works for me. Example code and output are shown below. It looks better when more iterations are used.

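    library(ggplot2)  # needed for stat_summary_bin() and aes()

    # Stack the aggregated profiles, then overlay the binned mean of `_yhat_`
    agrData <- rbind(model_profile_1$agr_profiles, model_profile_2$agr_profiles)
    plot(agrData) + stat_summary_bin(mapping = aes(y = `_yhat_`), fun = "mean", geom = "line")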
    

    The black line is the mean profile.
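If the binning is too coarse, one alternative (a different technique from the binned mean above) is to interpolate every profile onto a shared grid and average point by point. The following is only a minimal sketch: `average_profiles` is a hypothetical helper, not part of DALEX, and it assumes each input is a data frame shaped like `agr_profiles`, with numeric `_x_` and `_yhat_` columns and a `_vname_` column. The toy data frames `p1` and `p2` stand in for real profiles so the example runs without fitting any models.

```r
# Hypothetical helper (not a DALEX function): average several partial
# dependence profiles for one continuous variable on a common grid.
average_profiles <- function(profiles, variable, n_grid = 101) {
  # Keep only the rows belonging to the requested variable.
  sel <- lapply(profiles, function(p) {
    p <- as.data.frame(p)
    p[p[["_vname_"]] == variable, c("_x_", "_yhat_")]
  })
  # Restrict the grid to the range covered by every profile,
  # so that no profile has to be extrapolated.
  lo <- max(sapply(sel, function(p) min(p[["_x_"]])))
  hi <- min(sapply(sel, function(p) max(p[["_x_"]])))
  grid <- seq(lo, hi, length.out = n_grid)
  # Linearly interpolate each profile onto the grid: one column per model.
  yhat <- vapply(sel, function(p) {
    approx(p[["_x_"]], p[["_yhat_"]], xout = grid, ties = mean)$y
  }, numeric(n_grid))
  data.frame(x = grid,
             mean_yhat = rowMeans(yhat),
             sd_yhat = apply(yhat, 1, sd))
}

# Toy stand-ins for two `agr_profiles` data frames (assumed layout).
p1 <- data.frame(`_vname_` = "age", `_x_` = c(0, 10, 20),
                 `_yhat_` = c(0.2, 0.4, 0.6), check.names = FALSE)
p2 <- data.frame(`_vname_` = "age", `_x_` = c(0, 5, 20),
                 `_yhat_` = c(0.4, 0.5, 0.8), check.names = FALSE)

avg <- average_profiles(list(p1, p2), "age", n_grid = 5)
print(avg)
```

With the objects from the question, the call would presumably look like `average_profiles(list(model_profile_1$agr_profiles, model_profile_2$agr_profiles), "age")`; for 500 models, the list could be built with `lapply()` over the model_profile outputs. The `sd_yhat` column gives a rough idea of how much the profiles vary between iterations.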