rstan brms r-marginaleffects

Diagnostics for marginal effects from a brms model


I'm using marginaleffects to compute the marginal treatment effects from a Bayesian logit model fit using brms:

library(brms)
library(marginaleffects)
fit <- brm(y ~ (treat|state) + treat + age + sex + race, family = "bernoulli", data = dat)
mfx <- avg_slopes(fit, variables = "treat", by = "state")

The avg_slopes function returns a summary table, including the Estimate and 95% CI of the treatment effect for each state.

I would like to run some diagnostics on the chains for each marginal effect. At a minimum, I'd like to calculate R-hat for each treatment effect. I know how to do this if I can extract, chain by chain, the draws that correspond to each marginal effect. Is there a way to do this? I know about the posterior_draws() function, but as I understand it, it samples from the posterior, so it isn't helpful for extracting the chains or computing R-hat.
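For reference, this is roughly what the R-hat computation looks like once the draws are arranged by chain. It is a minimal sketch using only the posterior package, whose rhat() accepts a matrix with one row per iteration and one column per chain. The simulated chains (`mixed`, `stuck`) are hypothetical and do not come from the model in the question:

```r
library(posterior)

set.seed(123)
n_iter   <- 1000
n_chains <- 4

# Hypothetical draws: rows are iterations, columns are chains.
# All four chains sample the same standard normal, so they agree.
mixed <- matrix(rnorm(n_iter * n_chains), nrow = n_iter, ncol = n_chains)

# Same draws, but chain 4 is shifted, as if it were stuck in another mode.
stuck <- mixed
stuck[, 4] <- stuck[, 4] + 3

rhat_mixed <- rhat(mixed)  # close to 1
rhat_stuck <- rhat(stuck)  # well above the usual 1.01 threshold
c(mixed = rhat_mixed, stuck = rhat_stuck)
```

The only thing R-hat needs beyond the draws themselves is to know which draws belong to which chain, which is why the chain structure of the extracted marginal-effect draws matters.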


Solution

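  • Consider using posterior_draws() with shape = "rvar". Despite its name, it does not draw new samples: it extracts the draws that marginaleffects already computed from the model's posterior draws, and returns them as one rvar object per estimate, which the posterior package's diagnostic functions can then process: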

    library(brms)
    library(marginaleffects)
    library(posterior)
    
    # Fit a small Bayesian linear model and compute average slopes
    mod <- brm(mpg ~ wt * hp, data = mtcars)
    slo <- avg_slopes(mod)
    
    # One rvar (the full set of draws) per estimate
    dra <- posterior_draws(slo, shape = "rvar")
    
    # Convergence diagnostics, one value per marginal effect
    sapply(dra$rvar, posterior::rhat)
    #> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX 
    #>                     1.001199                     1.000394 
    
    sapply(dra$rvar, posterior::ess_basic)
    #> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX 
    #>                     4119.939                     3711.978 
    
    sapply(dra$rvar, posterior::mcse_mean)
    #> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX 
    #>                  0.000121686                  0.009075347
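The same approach should extend to the grouped case in the question (one marginal effect per state). The sketch below rests on stated assumptions: mtcars stands in for the question's `dat`, and `am` stands in for `state` as the `by` variable. It also checks posterior::nchains() on the extracted rvar, because R-hat only compares chains if the chain structure was carried over from the fit; whether posterior_draws() preserves it may depend on the marginaleffects version, so treat this as something to verify rather than a guarantee. Finally, posterior::draws_of(..., with_chains = TRUE) pulls out the raw draws of one estimate arranged by chain:

```r
library(brms)
library(marginaleffects)
library(posterior)

# Assumption: mtcars stands in for the question's `dat`, and `am` stands in
# for `state` as the grouping variable passed to `by`.
mod <- brm(mpg ~ wt * hp + am, data = mtcars,
           chains = 4, seed = 1, refresh = 0)

# One average slope of `wt` per level of `am`
slo <- avg_slopes(mod, variables = "wt", by = "am")
dra <- posterior_draws(slo, shape = "rvar")

# Should be 4 if the chain structure was carried over from the fit
n_ch <- posterior::nchains(dra$rvar)

# One row of diagnostics per group
diag_tab <- data.frame(
  am       = dra$am,
  estimate = dra$estimate,
  rhat     = sapply(dra$rvar, posterior::rhat),
  ess_bulk = sapply(dra$rvar, posterior::ess_bulk),
  ess_tail = sapply(dra$rvar, posterior::ess_tail)
)
print(diag_tab)

# Raw draws of the first estimate as an iterations x chains matrix
first <- dra$rvar[1]
by_chain <- matrix(
  posterior::draws_of(first, with_chains = TRUE),
  nrow = posterior::niterations(first),
  ncol = posterior::nchains(first)
)
posterior::rhat(by_chain)  # should match diag_tab$rhat[1]
```

With the question's objects, the equivalent would be `avg_slopes(fit, variables = "treat", by = "state")` followed by `state = dra$state` in the table. If nchains() returns 1, rhat() can only compare the two halves of the pooled draws, which is a weaker check than a genuine between-chain comparison.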