I'm trying to learn a bit about Pyro and building probabilistic neural networks with PyTorch. Normally, with a torch.nn.Module I can move it to the GPU with model.to('cuda'); however, this does not seem to work with a PyroModule. How does one correctly place a PyroModule model onto the GPU?
Example Model:
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
import torch.nn as nn
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.infer import SVI, Trace_ELBO, Predictive
class Model(PyroModule):
    def __init__(self, h1=20, h2=20):
        super().__init__()
        # Each layer's weights and biases are given standard-normal priors.
        self.fc1 = PyroModule[nn.Linear](1, h1)
        self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(0., 1.).expand([h1]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(0., 1.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(0., 1.).expand([h2]).to_event(1))
        self.fc3 = PyroModule[nn.Linear](h2, 1)
        self.fc3.weight = PyroSample(dist.Normal(0., 1.).expand([1, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(0., 1.).expand([1]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc3(x).squeeze()
        sigma = pyro.sample("sigma", dist.Uniform(0., 1.))
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu
then:
model = Model()
However, model.to('cuda') does not seem to actually move the model to the GPU.
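For example, calling the model on GPU data shows the symptom (an illustrative check of my own, assuming a CUDA device is available):

x = torch.linspace(-1., 1., 100, device='cuda')
# Fails with a device-mismatch RuntimeError: the PyroSample priors were built
# from CPU scalars, so the weights they draw stay on the CPU while x is on CUDA.
mu = model(x)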
Update: I'm not sure if this is a correct solution, but I find that if I replace PyroSample with pyro.nn.PyroParam, then the parameters are listed in the ParamDict and can be moved to the GPU.
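For reference, a minimal sketch of that PyroParam variant (the PointEstimateModel name is my own illustration; note that PyroParam makes the weights learnable point estimates rather than random variables with priors, so it changes the model's meaning):

import torch
import torch.nn as nn
from pyro.nn import PyroModule, PyroParam

class PointEstimateModel(PyroModule):
    def __init__(self, h1=20):
        super().__init__()
        self.fc1 = PyroModule[nn.Linear](1, h1)
        # PyroParam registers an ordinary learnable tensor under the hood,
        # so .to('cuda') moves it like any nn.Parameter.
        self.fc1.weight = PyroParam(torch.randn(h1, 1))
        self.fc1.bias = PyroParam(torch.randn(h1))

model = PointEstimateModel()
model.to('cuda')
print({name: p.device for name, p in model.named_parameters()})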
I've run into a similar problem. Even though PyroModule is subclassed from nn.Module, as far as I can tell the .to method does not carry PyroSample objects over the way one might expect for nn.Parameter objects. (I think it might work for PyroParam, like you say.)
I found this Pyro forum post to have a workable solution. It says you can initialize the PyroSample(dist...) calls with tensors that are already on the GPU. For example, change lines like your
self.fc1.weight = PyroSample(dist.Normal(0., 1.).expand([h1, 1]).to_event(2))
to
self.fc1.weight = PyroSample(dist.Normal(torch.tensor(0., device="cuda"), 1.).expand([h1, 1]).to_event(2))
I've found this is only a problem when you have multivariate PyroSample objects (i.e., ones built with the .expand(...).to_event(...) pattern here). If it's a univariate object, then the variable transfers from the CPU to the GPU without complaining.
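Putting that together, here is a minimal sketch of the full model with every multivariate prior anchored to the target device (the DeviceModel name and the device argument are my additions, following the same pattern):

import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class DeviceModel(PyroModule):
    def __init__(self, h1=20, h2=20, device="cuda"):
        super().__init__()
        # Build every prior from a tensor on the target device, so the
        # weights that PyroSample draws are created on the GPU directly.
        loc = torch.tensor(0., device=device)
        self.fc1 = PyroModule[nn.Linear](1, h1)
        self.fc1.weight = PyroSample(dist.Normal(loc, 1.).expand([h1, 1]).to_event(2))
        self.fc1.bias = PyroSample(dist.Normal(loc, 1.).expand([h1]).to_event(1))
        self.fc2 = PyroModule[nn.Linear](h1, h2)
        self.fc2.weight = PyroSample(dist.Normal(loc, 1.).expand([h2, h1]).to_event(2))
        self.fc2.bias = PyroSample(dist.Normal(loc, 1.).expand([h2]).to_event(1))
        self.fc3 = PyroModule[nn.Linear](h2, 1)
        self.fc3.weight = PyroSample(dist.Normal(loc, 1.).expand([1, h2]).to_event(2))
        self.fc3.bias = PyroSample(dist.Normal(loc, 1.).expand([1]).to_event(1))
        self.relu = nn.ReLU()

    def forward(self, x, y=None):
        x = x.reshape(-1, 1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        mu = self.fc3(x).squeeze()
        # sigma's prior is anchored to the input's device as well; it is a
        # scalar, which may transfer fine on its own (see the note above),
        # but building it on-device avoids any mixed-device arithmetic.
        sigma = pyro.sample("sigma", dist.Uniform(torch.tensor(0., device=x.device),
                                                  torch.tensor(1., device=x.device)))
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu

model = DeviceModel(device="cuda")
x = torch.linspace(-1., 1., 100, device="cuda")  # the data must be on the same device
mu = model(x)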