
Python PyTorch Pyro - Multivariate Distributions


How does one sample from a multivariate distribution in Pyro? I just want an (M, N) tensor of Beta samples, but the following doesn't work:

import torch
import pyro

M, N = 12, 10  # example sizes; any M != N reproduces the problem
with pyro.plate("theta_plate", M):
    theta = pyro.sample("theta",
                        pyro.distributions.Beta(concentration0=torch.ones(N),
                                                concentration1=torch.ones(N)))


Solution

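  • Use .to_event(n) to reinterpret the rightmost n batch dimensions of a distribution as event dimensions. Beta(torch.ones(N), 1.) has batch shape (N,), which collides with the plate of size M at dim -1; after .to_event(1), each draw is a single length-N vector, and the plate stacks M of them into an (M, N) tensor.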

    import torch
    import pyro
    import pyro.distributions as dist
    
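    def model(N, M):
        # The plate draws M conditionally independent copies along batch dim -1;
        # .to_event(1) turns the N Beta variables into one vector-valued event.
        with pyro.plate("theta_plate", M):
            theta = pyro.sample("theta", dist.Beta(torch.ones(N), 1.).to_event(1))
        return theta
    
    
    if __name__ == '__main__':
        print(model(10, 12).shape)  # torch.Size([12, 10]): M plate rows, N event columns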