
PyMC3/Arviz: CDF value from trace


I have a sample from PyMC3 and I'm trying to get a cumulative probability from it, e.g. P(X < 0). I currently use this:

trace = pymc3.sample(return_inferencedata=True)
prob_x_lt_zero = (trace.posterior.X < 0).sum() / trace.posterior.X.size

Is there a better way to do this, for example with a helper function from ArviZ or xarray?

I haven't found a .cdf() method or anything similar. It seems odd that such a basic function is missing while more advanced ones, such as trace.posterior.X.quantile(), are available.


Solution

  • I would recommend your original approach of evaluating the condition and averaging (which amounts to using the empirical CDF) rather than using a kernel density estimate (KDE).

    There is no equivalent in ArviZ or xarray that I know of, probably because there is none in numpy either (even though numpy has both quantile and percentile). SciPy does have one, scipy.stats.percentileofscore, but I wouldn't recommend it either unless you are working with discrete data and need its kind argument to control how ties are handled (i.e. would you care about, or even notice, any difference between using < and <=?). Note also that it returns a percentage between 0 and 100 rather than a probability, and, depending on your SciPy version, it may accept only a scalar score to evaluate the ECDF at.

    My recommendation, therefore, is to stick with your method but modify the implementation slightly, so that it also works when evaluating several values at once and when not reducing over all dimensions:

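    import arviz
    import xarray

    x = xarray.DataArray([-.1, 0, .1])  # values to evaluate at; skip if using a scalar
    post = arviz.load_arviz_data("rugby").posterior
    # (atts < x) broadcasts to (chain, draw, team, dim_0); averaging the boolean
    # mask over the sampling dims gives P(atts < x) for each team and value
    prob_atts_lt_x = (post.atts < x).mean(("chain", "draw"))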
    

    This returns, for each of the 6 teams, the probability that atts is below each of the 3 values:

    <xarray.DataArray (team: 6, dim_0: 3)>
    array([[0.    , 0.    , 0.0485],
           [0.347 , 0.975 , 1.    ],
           [0.    , 0.004 , 0.4245],
           [0.64  , 0.994 , 1.    ],
           [1.    , 1.    , 1.    ],
           [0.    , 0.    , 0.    ]])
    Coordinates:
      * team     (team) object 'Wales' 'France' 'Ireland' ... 'Italy' 'England'
    Dimensions without coordinates: dim_0
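The mask-and-average approach can also be packaged as a small reusable helper. This is a minimal sketch, not an ArviZ or xarray function: `ecdf_prob` is a hypothetical name, and the posterior is synthetic random data (an assumed stand-in for `trace.posterior.X`) so the snippet runs without PyMC3 or a dataset download. Passing `strict=False` switches from `<` to `<=`, and giving `values` a named dimension avoids the anonymous `dim_0` in the output above.

```python
import numpy as np
import xarray as xr


def ecdf_prob(samples, values, dims=("chain", "draw"), strict=True):
    """Hypothetical helper: empirical P(samples < values), or <= if strict=False.

    `samples` is a DataArray of posterior draws; `values` may be a scalar or a
    DataArray, which xarray broadcasts against `samples` by dimension name.
    `dims` are the sampling dimensions to average over (None reduces all dims).
    """
    mask = samples < values if strict else samples <= values
    # The mean of a boolean mask is the fraction of draws satisfying the
    # condition, i.e. the empirical CDF evaluated at `values`.
    return mask.mean(dims)


# Synthetic stand-in for a posterior: 4 chains x 500 draws x 3 groups,
# with group means -1, 0 and 1.
rng = np.random.default_rng(0)
post_x = xr.DataArray(
    rng.normal(loc=[-1.0, 0.0, 1.0], scale=1.0, size=(4, 500, 3)),
    dims=("chain", "draw", "group"),
)

# Evaluate at several values at once along a named "value" dimension.
values = xr.DataArray([-0.1, 0.0, 0.1], dims="value")
probs = ecdf_prob(post_x, values)
print(probs.dims)  # ('group', 'value')
```

Because only the sampling dimensions are reduced, the result keeps one probability per group and per evaluated value, mirroring the team-by-value output of the rugby example.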
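To see how scipy.stats.percentileofscore relates to the mask-and-average approach, here is a small sketch on made-up discrete draws, where ties make the choice between < and <= matter. In percentileofscore, kind="strict" counts values strictly below the score and kind="weak" counts values less than or equal to it; the result is a percentage, so it is divided by 100 here.

```python
import numpy as np
from scipy import stats

# Made-up discrete draws with a tie at 1.
draws = np.array([0, 1, 1, 2, 3])

# percentileofscore returns a percentage in [0, 100]; divide to get a probability.
p_strict = stats.percentileofscore(draws, 1, kind="strict") / 100  # P(X < 1)
p_weak = stats.percentileofscore(draws, 1, kind="weak") / 100      # P(X <= 1)

# The same quantities from averaging a boolean mask.
print(p_strict, np.mean(draws < 1))
print(p_weak, np.mean(draws <= 1))
```

For continuous posterior samples ties are essentially absent, so both kinds (and both comparison operators) give practically the same answer, which is why the plain mask average is usually enough.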