r visualization shap

sv_dependence plot with shapviz and same feature for x axis and colour


I am sorry to ask this question again (first attempt here: shapvix sv_dependence color_var based on same feature), but that post was closed because it did not provide enough detail. Below is an updated version that includes a reproducible example.

My issue is that I would like to use the sv_dependence function from shapviz to make dependence plots coloured based on the values of the variable on the x axis, just for aesthetic reasons.

library(shapviz)
library(ggplot2)
library(xgboost)

set.seed(1)

xvars <- c("log_carat", "cut", "color", "clarity")
X <- diamonds |> 
  transform(log_carat = log(carat)) |> 
  subset(select = xvars)
head(X)


# Fit (untuned) model
fit <- xgb.train(
  params = list(learning_rate = 0.1, nthread = 1), 
  data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price), nthread = 1),
  nrounds = 65
)

# SHAP analysis
X_explain <- X[sample(nrow(X), 2000), ]
shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)

# Not colored based on "log_carat" values
sv_dependence(shp, v = "log_carat", alpha = 0.5, color_var = "log_carat")

# Colored based on "clarity" values
sv_dependence(shp, v = "log_carat", alpha = 0.5, color_var = "clarity")

As some of you pointed out in the comments on the last post, the purpose of the function's color_var parameter is to highlight interaction effects between predictors. My goal is purely visual: the dependence between the SHAP values and the predictor's values stands out even better when the points are coloured according to that same predictor's scale. Although color_var was not designed for this, it used to work for me this way, but it no longer does (maybe because of a package update?). Can you suggest a workaround using this same function, or a ggplot2 strategy that achieves the same result? Thank you again for your support; I hope this additional information clarifies my issue.


Solution

  • Can you just color the dots? E.g.

    #install.packages("shapviz")
    library(ggplot2)
    library(xgboost)
    library(shapviz)
    
    set.seed(1)
    
    xvars <- c("log_carat", "cut", "color", "clarity")
    X <- diamonds |> 
      transform(log_carat = log(carat)) |> 
      subset(select = xvars)
    head(X)
    #>   log_carat       cut color clarity
    #> 1 -1.469676     Ideal     E     SI2
    #> 2 -1.560648   Premium     E     SI1
    #> 3 -1.469676      Good     E     VS1
    #> 4 -1.237874   Premium     I     VS2
    #> 5 -1.171183      Good     J     SI2
    #> 6 -1.427116 Very Good     J    VVS2
    
    
    # Fit (untuned) model
    fit <- xgb.train(
      params = list(learning_rate = 0.1, nthread = 1), 
      data = xgb.DMatrix(data.matrix(X), label = log(diamonds$price), nthread = 1),
      nrounds = 65
    )
    
    # SHAP analysis
    X_explain <- X[sample(nrow(X), 2000), ]
    shp <- shapviz(fit, X_pred = data.matrix(X_explain), X = X_explain)
    
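    # Draw the plain dependence plot (no interaction colouring), then
    # overlay a second point layer coloured by the x-axis variable itself
    sv_dependence(shp, v = "log_carat", alpha = 0.5, color_var = NULL) +
      geom_jitter(aes(color = log_carat))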
    

    Created on 2025-06-03 with reprex v2.1.1

    Or have I misunderstood?
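
If you would rather not stack a second point layer on top of the one sv_dependence() already draws, the same picture can probably be built directly with ggplot2. The sketch below is only an illustration: dependence_self_colored() is a hypothetical helper, not part of shapviz, and it assumes the accessors get_shap_values() and get_feature_values() return the SHAP matrix and the feature data frame stored in the shapviz object.

```r
library(ggplot2)
library(shapviz)

# Hypothetical helper (not a shapviz function): dependence plot of feature `v`
# in which the points are coloured by the values of `v` itself.
dependence_self_colored <- function(shp, v, alpha = 0.5) {
  df <- data.frame(
    feature = get_feature_values(shp)[[v]],  # x-axis values of the feature
    shap    = get_shap_values(shp)[, v]      # SHAP values of the same feature
  )
  ggplot(df, aes(x = feature, y = shap, color = feature)) +
    geom_point(alpha = alpha) +
    labs(x = v, y = paste("SHAP value of", v), color = v)
}

# Usage with the `shp` object from the reprex above:
# dependence_self_colored(shp, "log_carat")
```

Because the result is an ordinary ggplot object, you can add any colour scale you like afterwards, for example `+ scale_color_viridis_c()`.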