I'm using marginaleffects to compute the marginal treatment effects from a Bayesian logit model fit with brms:
fit <- brm(y ~ (treat | state) + treat + age + sex + race, family = "bernoulli", data = dat)
mfx <- avg_slopes(fit, variables = "treat", by = "state")
The avg_slopes function returns a summary table with the estimate and 95% credible interval of the treatment effect for each state.
I would like to run some diagnostics on the chains for each marginal effect. At a minimum, I'd like to calculate Rhat for each treatment effect. I know how to do this if I can extract the draws, by chain, that correspond to each marginal effect, so is there a way to do that? I know there is the posterior_draws function, but this samples from the posterior, so it isn't helpful for extracting the chains or computing Rhats.
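For reference, once the chain-wise draws of one effect are in hand, posterior::rhat() accepts them as an iterations × chains matrix, so the diagnostic itself is a single call. A minimal sketch, assuming the posterior package is installed; the matrices below are simulated stand-ins, not real marginal-effect draws:

```r
library(posterior)

set.seed(1)
# Simulated stand-in for the draws of ONE marginal effect, arranged as
# iterations (rows) x chains (columns): 1000 iterations from each of 4 chains.
good <- matrix(rnorm(1000 * 4), nrow = 1000, ncol = 4)

# A badly mixed version: chain 4 sits in a different region of the posterior.
bad <- good
bad[, 4] <- bad[, 4] + 3

rhat(good)      # close to 1: the chains agree
rhat(bad)       # well above the usual 1.01 threshold
ess_bulk(good)  # bulk effective sample size pooled over all chains
```

The second matrix shows why the chain labels matter: the same pooled values can look fine or badly mixed depending on which chain each draw came from.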
Consider using posterior_draws() to output rvar objects, which can then be processed by the posterior package:
library(brms)
library(marginaleffects)
library(posterior)
mod <- brm(mpg ~ wt * hp, data = mtcars)
slo <- avg_slopes(mod)
dra <- posterior_draws(slo, shape = "rvar")
sapply(dra$rvar, posterior::rhat)
#> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX
#>                     1.001199                     1.000394
sapply(dra$rvar, posterior::ess_basic)
#> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX
#>                     4119.939                     3711.978
sapply(dra$rvar, posterior::mcse_mean)
#> hp.main_marginaleffect.dY/dX wt.main_marginaleffect.dY/dX
#>                  0.000121686                  0.009075347
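To get one row of diagnostics per state rather than a named vector, the same idea can be wrapped in a small function. This is a sketch only: mfx_convergence() is a hypothetical helper, not part of marginaleffects or posterior, and it assumes the draws table has an rvar column (as in the code above) that carries the chain count from the brms fit, which the Rhat values above suggest. The demo uses simulated rvars instead of a fitted model:

```r
library(posterior)

# Hypothetical helper (not part of marginaleffects or posterior): given the
# data frame returned by posterior_draws(..., shape = "rvar"), return one row
# of convergence diagnostics per estimate, keeping the identifying columns.
mfx_convergence <- function(draws_tbl, id_cols = "term") {
  draws_tbl <- as.data.frame(draws_tbl)
  rv <- draws_tbl$rvar
  # With a single chain, Rhat can only compare split halves of that chain.
  if (nchains(rv) < 2) {
    warning("rvar has a single chain; between-chain mixing cannot be assessed")
  }
  # Apply a posterior diagnostic to each element of the rvar separately.
  per_element <- function(f) {
    vapply(seq_along(rv), function(i) as.numeric(f(rv[[i]])), numeric(1))
  }
  cbind(
    draws_tbl[id_cols],
    rhat     = per_element(rhat),
    ess_bulk = per_element(ess_bulk),
    ess_tail = per_element(ess_tail)
  )
}

# Demo with simulated draws standing in for real marginal effects:
# two estimates, 4000 draws each, split into 4 chains of 1000 iterations.
set.seed(3)
fake <- data.frame(
  term = c("a", "b"),
  rvar = rvar(matrix(rnorm(4000 * 2), ncol = 2), nchains = 4)
)
mfx_convergence(fake)
```

For the model in the question, the call would presumably be mfx_convergence(posterior_draws(mfx, shape = "rvar"), id_cols = c("term", "state")), assuming the by = "state" grouping shows up as a state column in the draws table.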