Tags: python, pymc3, arviz

Reproducing the old PyMC3-style grouped traceplot with ArviZ


I have an old blog post in which I train a PyMC3 model. You can find the blog post here, but the gist of the model is shown below.

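import pymc3 as pm

# df has columns weight, time and diet (diets coded 1..n_diets)
with pm.Model() as model:
    # mean weight: shared intercept plus one (positive) time slope per diet
    mu_intercept = pm.Normal('mu_intercept', mu=40, sd=5)
    mu_slope = pm.HalfNormal('mu_slope', 10, shape=(n_diets,))
    mu = mu_intercept + mu_slope[df.diet-1] * df.time
    # spread: also grows linearly in time, with one slope per diet
    sigma_intercept = pm.HalfNormal('sigma_intercept', sd=2)
    sigma_slope = pm.HalfNormal('sigma_slope', sd=2, shape=n_diets)
    sigma = sigma_intercept + sigma_slope[df.diet-1] * df.time
    weight = pm.Normal('weight', mu=mu, sd=sigma, observed=df.weight)
    # fit with full-rank ADVI instead of MCMC sampling
    approx = pm.fit(20000, random_seed=42, method="fullrank_advi")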

In this dataset, I'm estimating the effect of diet on the weight of chickens. This is what the traceplot looks like:

[Image: traceplot made with the old PyMC3 API]

Look at how pretty it is! Each diet has its own line! Beautiful!

ArviZ Changes

This traceplot was made using the older PyMC3 API. Nowadays this functionality has moved to ArviZ, so I tried to redo this work, but... the plot looks very different.

[Image: traceplot made with ArviZ]

The code I'm using here is slightly different. I'm using pm.Data now, but I doubt that is what causes this difference.

import pymc3 as pm

# dummies: presumably a one-hot matrix of the diet column (one column per diet), e.g. from pd.get_dummies
with pm.Model() as mod:
    time_in = pm.Data("time_in", df['time'].astype(float))
    diet_in = pm.Data("diet_in", dummies)

    intercept = pm.Normal("intercept", 0, 2)
    time_effect = pm.Normal("time_weight_effect", 0, 2, shape=(4,))
    diet = pm.Categorical("diet", p=[0.25, 0.25, 0.25, 0.25], shape=(4,), observed=diet_in)
    sigma = pm.HalfNormal("sigma", 2)
    sigma_time_effect = pm.HalfNormal("time_sigma_effect", 2, shape=(4,))
    # rows of diet_in are one-hot, so the dot product selects each row's diet-specific effect
    weight = pm.Normal("weight",
                       mu=intercept + time_effect.dot(diet_in.T)*time_in,
                       sd=sigma + sigma_time_effect.dot(diet_in.T)*time_in,
                       observed=df.weight)
    trace = pm.sample(5000, return_inferencedata=True)
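The old model picks each observation's slope by integer indexing (`mu_slope[df.diet-1]`), while the new one multiplies by a one-hot matrix (`time_effect.dot(diet_in.T)`). A quick NumPy check, assuming `dummies` was built with `pd.get_dummies` on the diet column (the post does not show this) and using made-up toy values, suggests both forms select the same per-diet slope. So the new model should still carry one parameter per diet, and the difference is more likely in how the plot is drawn:

```python
import numpy as np
import pandas as pd

# Hypothetical toy data: diet labels coded 1..4 and observation times.
diet = pd.Series([1, 3, 2, 4, 1])
time = np.array([0.0, 2.0, 4.0, 6.0, 8.0])
slopes = np.array([0.5, 1.0, 1.5, 2.0])  # placeholder: one slope per diet

# Old style: integer indexing, as in mu_slope[df.diet-1] * df.time
old = slopes[diet.to_numpy() - 1] * time

# New style: one-hot matrix, as in time_effect.dot(diet_in.T) * time_in
dummies = pd.get_dummies(diet).to_numpy(dtype=float)  # shape (n_obs, 4)
new = slopes.dot(dummies.T) * time

print(np.allclose(old, new))  # True
```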

What do I need to do to get the separate colors per diet back?


Solution

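  • The new plot_trace function has a parameter for this. Passing compact=True draws all entries of a vector-valued variable, such as the four diet effects, in a single plot, which does the trick:

    import arviz as az
    az.plot_trace(trace, compact=True)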
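A minimal sketch of what compact changes, using synthetic draws (random placeholders, not output from the chicken model) packed into an InferenceData object with az.from_dict. The dimension name `diet` and its coordinate values are assumptions for illustration; naming the dimension also gives readable per-diet labels. With compact=True the four diet effects share one row of panels, while compact=False gives each diet its own row:

```python
import matplotlib
matplotlib.use("Agg")  # headless backend so the sketch runs without a display
import numpy as np
import arviz as az

# Synthetic placeholder draws: 2 chains x 500 draws x 4 diets.
rng = np.random.default_rng(42)
draws = rng.normal(loc=[1.0, 2.0, 3.0, 4.0], scale=0.1, size=(2, 500, 4))

# Label the last axis as a named "diet" dimension (assumed name).
idata = az.from_dict(
    posterior={"time_weight_effect": draws},
    coords={"diet": [1, 2, 3, 4]},
    dims={"time_weight_effect": ["diet"]},
)

# compact=True: one row of (density, trace) panels with all diets overlaid.
axes_compact = az.plot_trace(idata, var_names=["time_weight_effect"], compact=True)

# compact=False: one row of panels per diet.
axes_split = az.plot_trace(idata, var_names=["time_weight_effect"], compact=False)

print(axes_compact.shape, axes_split.shape)
```

In recent ArviZ releases, the compact_prop and chain_prop arguments control which line property (for example color or linestyle) distinguishes diets versus chains; it is worth checking the plot_trace docs for the installed version.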