I am fairly new to R and Stan, and I would like to write my own model in Stan code. The problem is that I don't know how to obtain the `estimate__` values that `conditional_effects(brmsfit)` produces when using `library(brms)`.
Here is an example of what I would like to obtain:
library(rstan)
library(brms)
# Simulate a small data set. N drives both draws so that x and y always have
# matching lengths if the sample size is changed.
N <- 10
y <- rnorm(N)
x <- rnorm(N)
df <- data.frame(x, y)

# Fit the linear regression y ~ x with brms.
fit <- brm(y ~ x, data = df)

# Compute the conditional effects and print the table for the predictor x
# (columns estimate__, se__, lower__, upper__, as shown in the output below).
data <- conditional_effects(fit)
print(data[["x"]])
Which gives this output:
x y cond__ effect1__ estimate__ se__
1 -1.777412243 0.1417486 1 -1.777412243 0.08445399 0.5013894
2 -1.747889444 0.1417486 1 -1.747889444 0.08592914 0.4919022
3 -1.718366646 0.1417486 1 -1.718366646 0.08487412 0.4840257
4 -1.688843847 0.1417486 1 -1.688843847 0.08477227 0.4744689
5 -1.659321048 0.1417486 1 -1.659321048 0.08637019 0.4671830
6 -1.629798249 0.1417486 1 -1.629798249 0.08853233 0.4612196
7 -1.600275450 0.1417486 1 -1.600275450 0.08993511 0.4566040
8 -1.570752651 0.1417486 1 -1.570752651 0.08987979 0.4501722
9 -1.541229852 0.1417486 1 -1.541229852 0.09079337 0.4415650
10 -1.511707053 0.1417486 1 -1.511707053 0.09349952 0.4356073
11 -1.482184255 0.1417486 1 -1.482184255 0.09382594 0.4292237
12 -1.452661456 0.1417486 1 -1.452661456 0.09406637 0.4229115
13 -1.423138657 0.1417486 1 -1.423138657 0.09537000 0.4165933
14 -1.393615858 0.1417486 1 -1.393615858 0.09626168 0.4126735
15 -1.364093059 0.1417486 1 -1.364093059 0.09754818 0.4060894
16 -1.334570260 0.1417486 1 -1.334570260 0.09737763 0.3992320
17 -1.305047461 0.1417486 1 -1.305047461 0.09646332 0.3929951
18 -1.275524662 0.1417486 1 -1.275524662 0.09713718 0.3870211
19 -1.246001864 0.1417486 1 -1.246001864 0.09915170 0.3806628
20 -1.216479065 0.1417486 1 -1.216479065 0.10046754 0.3738948
21 -1.186956266 0.1417486 1 -1.186956266 0.10192677 0.3675363
22 -1.157433467 0.1417486 1 -1.157433467 0.10329695 0.3613282
23 -1.127910668 0.1417486 1 -1.127910668 0.10518868 0.3533583
24 -1.098387869 0.1417486 1 -1.098387869 0.10533191 0.3484098
25 -1.068865070 0.1417486 1 -1.068865070 0.10582833 0.3442075
26 -1.039342271 0.1417486 1 -1.039342271 0.10864510 0.3370518
27 -1.009819473 0.1417486 1 -1.009819473 0.10830692 0.3325785
28 -0.980296674 0.1417486 1 -0.980296674 0.11107417 0.3288747
29 -0.950773875 0.1417486 1 -0.950773875 0.11229667 0.3249769
30 -0.921251076 0.1417486 1 -0.921251076 0.11420108 0.3216303
31 -0.891728277 0.1417486 1 -0.891728277 0.11533604 0.3160908
32 -0.862205478 0.1417486 1 -0.862205478 0.11671013 0.3099456
33 -0.832682679 0.1417486 1 -0.832682679 0.11934724 0.3059504
34 -0.803159880 0.1417486 1 -0.803159880 0.12031792 0.3035792
35 -0.773637082 0.1417486 1 -0.773637082 0.12114301 0.2985330
36 -0.744114283 0.1417486 1 -0.744114283 0.12149371 0.2949334
37 -0.714591484 0.1417486 1 -0.714591484 0.12259197 0.2915398
38 -0.685068685 0.1417486 1 -0.685068685 0.12308763 0.2905327
39 -0.655545886 0.1417486 1 -0.655545886 0.12409683 0.2861451
40 -0.626023087 0.1417486 1 -0.626023087 0.12621634 0.2834400
41 -0.596500288 0.1417486 1 -0.596500288 0.12898609 0.2838938
42 -0.566977489 0.1417486 1 -0.566977489 0.12925969 0.2802667
43 -0.537454691 0.1417486 1 -0.537454691 0.13050938 0.2782553
44 -0.507931892 0.1417486 1 -0.507931892 0.12968382 0.2765127
45 -0.478409093 0.1417486 1 -0.478409093 0.13252478 0.2735946
46 -0.448886294 0.1417486 1 -0.448886294 0.13414535 0.2727640
47 -0.419363495 0.1417486 1 -0.419363495 0.13453109 0.2710725
48 -0.389840696 0.1417486 1 -0.389840696 0.13526957 0.2683500
49 -0.360317897 0.1417486 1 -0.360317897 0.13675913 0.2665745
50 -0.330795098 0.1417486 1 -0.330795098 0.13987067 0.2658021
51 -0.301272300 0.1417486 1 -0.301272300 0.14111051 0.2668740
52 -0.271749501 0.1417486 1 -0.271749501 0.14382292 0.2680711
53 -0.242226702 0.1417486 1 -0.242226702 0.14531118 0.2662193
54 -0.212703903 0.1417486 1 -0.212703903 0.14656473 0.2670958
55 -0.183181104 0.1417486 1 -0.183181104 0.14689102 0.2677249
56 -0.153658305 0.1417486 1 -0.153658305 0.14749250 0.2698547
57 -0.124135506 0.1417486 1 -0.124135506 0.14880275 0.2711767
58 -0.094612707 0.1417486 1 -0.094612707 0.15072864 0.2719037
59 -0.065089909 0.1417486 1 -0.065089909 0.15257772 0.2720895
60 -0.035567110 0.1417486 1 -0.035567110 0.15434018 0.2753563
61 -0.006044311 0.1417486 1 -0.006044311 0.15556588 0.2783308
62 0.023478488 0.1417486 1 0.023478488 0.15481341 0.2802336
63 0.053001287 0.1417486 1 0.053001287 0.15349716 0.2833364
64 0.082524086 0.1417486 1 0.082524086 0.15432904 0.2868926
65 0.112046885 0.1417486 1 0.112046885 0.15637411 0.2921039
66 0.141569684 0.1417486 1 0.141569684 0.15793097 0.2979247
67 0.171092482 0.1417486 1 0.171092482 0.15952338 0.3022751
68 0.200615281 0.1417486 1 0.200615281 0.15997047 0.3048768
69 0.230138080 0.1417486 1 0.230138080 0.16327957 0.3087545
70 0.259660879 0.1417486 1 0.259660879 0.16372900 0.3125599
71 0.289183678 0.1417486 1 0.289183678 0.16395417 0.3185642
72 0.318706477 0.1417486 1 0.318706477 0.16414444 0.3240570
73 0.348229276 0.1417486 1 0.348229276 0.16570600 0.3273931
74 0.377752075 0.1417486 1 0.377752075 0.16556032 0.3316680
75 0.407274873 0.1417486 1 0.407274873 0.16815162 0.3391713
76 0.436797672 0.1417486 1 0.436797672 0.16817144 0.3465403
77 0.466320471 0.1417486 1 0.466320471 0.16790241 0.3514764
78 0.495843270 0.1417486 1 0.495843270 0.16941330 0.3590708
79 0.525366069 0.1417486 1 0.525366069 0.17068468 0.3662851
80 0.554888868 0.1417486 1 0.554888868 0.17238535 0.3738123
81 0.584411667 0.1417486 1 0.584411667 0.17358253 0.3796033
82 0.613934466 0.1417486 1 0.613934466 0.17521059 0.3869863
83 0.643457264 0.1417486 1 0.643457264 0.17617046 0.3939509
84 0.672980063 0.1417486 1 0.672980063 0.17710931 0.3967577
85 0.702502862 0.1417486 1 0.702502862 0.17816611 0.4026686
86 0.732025661 0.1417486 1 0.732025661 0.17998354 0.4094216
87 0.761548460 0.1417486 1 0.761548460 0.18085939 0.4165644
88 0.791071259 0.1417486 1 0.791071259 0.18114271 0.4198687
89 0.820594058 0.1417486 1 0.820594058 0.18294576 0.4255245
90 0.850116857 0.1417486 1 0.850116857 0.18446785 0.4333511
91 0.879639655 0.1417486 1 0.879639655 0.18498697 0.4407155
92 0.909162454 0.1417486 1 0.909162454 0.18729221 0.4472631
93 0.938685253 0.1417486 1 0.938685253 0.18952720 0.4529227
94 0.968208052 0.1417486 1 0.968208052 0.19203126 0.4579841
95 0.997730851 0.1417486 1 0.997730851 0.19408999 0.4671136
96 1.027253650 0.1417486 1 1.027253650 0.19551024 0.4751111
97 1.056776449 0.1417486 1 1.056776449 0.19700981 0.4804208
98 1.086299247 0.1417486 1 1.086299247 0.19756573 0.4850098
99 1.115822046 0.1417486 1 1.115822046 0.20044626 0.4915511
100 1.145344845 0.1417486 1 1.145344845 0.20250046 0.4996890
lower__ upper__
1 -1.0567858 1.1982199
2 -1.0438136 1.1831539
3 -1.0228641 1.1707170
4 -1.0072313 1.1596104
5 -0.9864567 1.1438521
6 -0.9689320 1.1282532
7 -0.9505741 1.1173943
8 -0.9357609 1.0983966
9 -0.9230198 1.0859565
10 -0.9104617 1.0757511
11 -0.8874429 1.0631791
12 -0.8687644 1.0467475
13 -0.8513190 1.0348922
14 -0.8290140 1.0236083
15 -0.8126063 1.0166800
16 -0.7975146 1.0011153
17 -0.7869631 0.9873863
18 -0.7760327 0.9721754
19 -0.7551183 0.9585837
20 -0.7427828 0.9479480
21 -0.7269582 0.9405559
22 -0.7072756 0.9284436
23 -0.6975987 0.9161489
24 -0.6884648 0.9040642
25 -0.6684576 0.8923201
26 -0.6535668 0.8811996
27 -0.6517693 0.8714208
28 -0.6394743 0.8652541
29 -0.6235719 0.8542377
30 -0.6127188 0.8433206
31 -0.6017256 0.8346912
32 -0.5845027 0.8192662
33 -0.5701008 0.8098853
34 -0.5596900 0.7982326
35 -0.5473666 0.7980605
36 -0.5340069 0.7908127
37 -0.5239994 0.7826979
38 -0.5124559 0.7811926
39 -0.4986325 0.7786670
40 -0.5044564 0.7745791
41 -0.4940340 0.7699341
42 -0.4871297 0.7698303
43 -0.4808839 0.7678166
44 -0.4790951 0.7662335
45 -0.4711604 0.7576184
46 -0.4690302 0.7577330
47 -0.4675442 0.7567887
48 -0.4673520 0.7554134
49 -0.4649256 0.7499373
50 -0.4600178 0.7494690
51 -0.4500426 0.7500552
52 -0.4475863 0.7505488
53 -0.4437339 0.7513191
54 -0.4429276 0.7564214
55 -0.4427087 0.7578937
56 -0.4451014 0.7613821
57 -0.4418548 0.7706546
58 -0.4377409 0.7787030
59 -0.4397108 0.7882644
60 -0.4462651 0.8026011
61 -0.4538979 0.8069187
62 -0.4542826 0.8163290
63 -0.4557042 0.8285206
64 -0.4572005 0.8335650
65 -0.4638491 0.8413812
66 -0.4681885 0.8539095
67 -0.4775714 0.8633141
68 -0.4888333 0.8698490
69 -0.4952363 0.8791527
70 -0.4975383 0.8833882
71 -0.5088667 0.8863114
72 -0.5197474 0.8951534
73 -0.5316745 0.9085101
74 -0.5409388 0.9207023
75 -0.5572803 0.9282691
76 -0.5643576 0.9357900
77 -0.5751774 0.9517092
78 -0.5855919 0.9625510
79 -0.5995727 0.9781417
80 -0.6115650 0.9946185
81 -0.6198287 1.0071916
82 -0.6297608 1.0208370
83 -0.6447637 1.0357034
84 -0.6511860 1.0506364
85 -0.6659993 1.0608813
86 -0.6794852 1.0702993
87 -0.6893830 1.0801824
88 -0.7040491 1.1026626
89 -0.7183266 1.1196308
90 -0.7387399 1.1401544
91 -0.7541057 1.1561184
92 -0.7608552 1.1701851
93 -0.7783620 1.1855296
94 -0.7920760 1.2014060
95 -0.8063188 1.2157463
96 -0.8224106 1.2307841
97 -0.8377605 1.2484814
98 -0.8530954 1.2580503
99 -0.8684646 1.2731355
100 -0.8840083 1.2891893
From this I can easily plot the `estimate__` column against the `x` column to obtain my linear regression line.
Now assume I want to do the same thing, but with my own Stan code, using the `stan()` function:
library(rstan)
# Simulate a small data set. N is also passed to Stan as the declared length
# of x and y, so the draws must be sized by N rather than a repeated literal.
N <- 10
y <- rnorm(N)
x <- rnorm(N)
df <- data.frame(x, y)

# Compile and sample the hand-written Stan program, passing the data as a
# named list whose names match the Stan data block (y, x, N).
fit <- stan(file = "stan_test.stan", data = list(y = y, x = x, N = N))
print(fit)
Which yields the output:
Inference for Stan model: stan_test.
4 chains, each with iter=2000; warmup=1000; thin=1;
post-warmup draws per chain=1000, total post-warmup draws=4000.
mean se_mean sd 2.5% 25% 50% 75% 97.5% n_eff Rhat
alpha -0.35 0.01 0.43 -1.23 -0.62 -0.35 -0.09 0.50 2185 1
beta -0.26 0.01 0.57 -1.41 -0.60 -0.25 0.08 0.86 2075 1
sigma 1.26 0.01 0.41 0.74 0.99 1.17 1.43 2.27 1824 1
lp__ -6.19 0.04 1.50 -10.18 -6.87 -5.79 -5.07 -4.48 1282 1
Samples were drawn using NUTS(diag_e) at Fri Jun 03 10:08:50 2022.
For each parameter, n_eff is a crude measure of effective sample size,
and Rhat is the potential scale reduction factor on split chains (at
convergence, Rhat=1).
How would I obtain the same `estimate__` column, as well as the `lower__` and `upper__` columns?
Note, I know I can easily plot it using the intercept and slope means, but I would like to plot more complex models that can't be plotted as easily as such -- this is just a simple example.
My understanding is that `brms` estimates conditional effects by applying the model formula to a range of values for the variable you're interested in, with other variables set to appropriate baseline values. In order to do this, `brms` has to generate the new dataset, apply the model to it, and summarize appropriately. To my knowledge, `rstan` doesn't have built-in functions that do this; this means that, when we move from `brms` to `rstan`, we have to do these steps ourselves.
Here's one way to do it. I've done the first two steps (generate a new dataset and apply the model to it) within Stan, although it would be possible to use R instead.
I've added a `transformed data` block to the basic Stan program. It finds the min and max observed values of `x` and creates a vector of 100 evenly spaced points between those two values. If you have more than one predictor for which you want to estimate conditional effects, you'll need to create a separate vector for each one.
// Observed data supplied from R; the list passed to stan() must provide
// N, x and y under these names.
data {
// Number of observations; sets the length of x and y.
int<lower=0> N;
// Predictor values (used on the right-hand side of the regression).
vector[N] x;
// Outcome values (modelled as normal around alpha + beta * x).
vector[N] y;
}
transformed data {
  // How many values of the continuous variable will we use to estimate
  // conditional effects?
  // 100, to match the default behavior of conditional_effects.
  int n_cond_points = 100;
  // Grid of x values at which the regression line will be evaluated.
  vector[n_cond_points] x_cond_internal;
  // Observed range of x, computed once instead of on every loop iteration.
  real x_min = min(x);
  real x_max = max(x);
  // n points spanning [x_min, x_max] are separated by n - 1 equal gaps.
  // Dividing by n would make the last gap twice as wide as the others once
  // the final point is pinned to the maximum.
  real point_diff = (x_max - x_min) / (n_cond_points - 1);
  for (i in 1:n_cond_points) {
    // Compute each point directly from x_min rather than from the previous
    // point, so rounding error does not accumulate along the grid.
    x_cond_internal[i] = x_min + (i - 1) * point_diff;
  }
  // Pin the last grid point to the observed maximum exactly.
  x_cond_internal[n_cond_points] = x_max;
}
I used the `generated quantities` block to apply the model to the new dataset. A few things are noteworthy here:
- We can't extract variables declared in the `transformed data` block out of a `stanfit` object. As suggested by this answer, I've copied the new dataset into a variable in the `generated quantities` block so we can get it out of the `stanfit` object. (It will be the same across all draws.)
- The regression formula has to be repeated from the `model` block. If the model changes there, it must be changed by hand in the same way in the `generated quantities` block.
parameters {
real alpha;
real beta;
real<lower=0> sigma;
}
model {
  // Linear predictor: expected value of y at each observed x.
  vector[N] mu = alpha + beta * x;
  // Normal likelihood with a common residual standard deviation.
  // No priors are stated explicitly for alpha, beta or sigma.
  y ~ normal(mu, sigma);
}
generated quantities {
  // Variables from transformed data cannot be extracted from the stanfit
  // object, so the grid is re-exported here under the name x_cond.
  vector[n_cond_points] x_cond = x_cond_internal;
  // Expected value of y at every grid point, computed in one vectorised
  // expression. This repeats the formula from the model block; if that
  // formula changes, this one must be changed by hand to match.
  vector[n_cond_points] y_cond = alpha + beta * x_cond;
}
When we fit this Stan model, we get one estimate of `y_cond` per value of `x_cond` per draw, which is exactly what we want. We can summarize over draws in R:
library(tidyverse)
library(tidybayes)
# Refit the Stan program, which now also produces x_cond and y_cond.
fit2 <- stan("stan_test.stan", data = list(y = y, x = x, N = N))

# Collapse the posterior draws to one row per grid point: the median of y_cond
# as the point estimate and the 2.5% / 97.5% quantiles as the interval bounds.
# The column names mirror those returned by brms::conditional_effects().
cond.effects.df <- fit2 %>%
  spread_draws(x_cond[i], y_cond[i]) %>%
  ungroup() %>%
  group_by(i, x = x_cond) %>%
  summarise(
    estimate__ = median(y_cond),
    lower__ = quantile(y_cond, 0.025),
    upper__ = quantile(y_cond, 0.975),
    .groups = "drop"
  )
The results of this procedure look pretty much the same as the output of `brms`. Here's what I got:
theme_set(theme_bw())

# Stack the brms table and the hand-computed table, tagging each row with the
# method that produced it so the two can be overlaid in one plot.
comparison_df <- bind_rows(
  mutate(data[["x"]], i = row_number(), method = "brms"),
  mutate(cond.effects.df, method = "by hand")
)

# Line = point estimate, ribbon = interval between lower__ and upper__.
ggplot(
  comparison_df,
  aes(x = x, color = method, fill = method, group = method)
) +
  geom_line(aes(y = estimate__)) +
  geom_ribbon(aes(ymin = lower__, ymax = upper__), color = NA, alpha = 0.2)