How to set the device ('cuda' or 'cpu') for a Pyro PyroModule

I'm trying to learn a bit about Pyro and how to build probabilistic neural networks with PyTorch. Normally, I can move a torch.nn.Module to the GPU with model.to('cuda'); however, this does not seem to work with a PyroModule. How does one correctly place a PyroModule model onto the GPU?

Example Model:

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO, Predictive

class Model(PyroModule):
    def __init__(self, h1=20, h2=20):
        super().__init__()  # required before any submodules can be assigned
        self.fc1 = PyroModule[nn.Linear](1, h1)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([h1]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))
        self.fc3 = PyroModule[nn.Linear](h2, 1)
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([1, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 1.).expand([1]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc3(x).squeeze()
        sigma = pyro.sample("sigma", dist.Uniform(0., 1.))
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu

Then: model = Model(); however, model.to('cuda') does not seem to actually move the model to the GPU.

Update: I'm not sure if this is a correct solution, but I find that if I replace PyroSample with pyro.nn.PyroParam, the attributes are listed in the ParamDict and can be moved to the GPU.


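  • I've run into a similar problem. Even though PyroModule is a subclass of nn.Module, as far as I can tell the .to() method does not move PyroSample attributes the way it moves nn.Parameter objects. That is because .to() only moves registered parameters and buffers, and a PyroSample attribute is neither: its prior distribution keeps its tensors on whatever device they were created on. (I think it does work for PyroParam, as you say.)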

    I found this Pyro forum post to have a workable solution. It says you can construct the PyroSample(dist...) priors from tensors that are already on the GPU. For example, change lines like your

    self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, 1]).to_event(2))

    to

    self.fc1.weight = PyroSample(dist.Normal(torch.tensor(0., device="cuda"), 1.).expand([h1, 1]).to_event(2))

    I've found this is only a problem when you have multivariate PyroSample objects (i.e., you use .expand(...).to_event(...) as here). If it's a univariate object, the variable moves from the CPU to the GPU without any complaint.