I tried to define a simple model in PyTorch. The model computes the negative log-probability under a Gaussian distribution:
import torch
import torch.nn as nn

class GaussianModel(nn.Module):
    def __init__(self):
        super(GaussianModel, self).__init__()
        self.register_parameter('mean', nn.Parameter(torch.zeros(1),
                                                     requires_grad=True))
        self.pdf = torch.distributions.Normal(self.state_dict()['mean'],
                                              torch.tensor([1.0]))

    def forward(self, x):
        return -self.pdf.log_prob(x)

model = GaussianModel()
Then I tried to optimize the mean parameter:
optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
for _ in range(5):
    optimizer.zero_grad()
    nll = model(torch.tensor([3.0], requires_grad=True))
    nll.backward()
    optimizer.step()
    print('mean : ', model.state_dict()['mean'],
          ' - Negative Loglikelihood : ', nll.item())
But it seems the gradient is zero and mean does not change:
mean : tensor([0.]) - Negative Loglikelihood : 5.418938636779785
mean : tensor([0.]) - Negative Loglikelihood : 5.418938636779785
mean : tensor([0.]) - Negative Loglikelihood : 5.418938636779785
mean : tensor([0.]) - Negative Loglikelihood : 5.418938636779785
mean : tensor([0.]) - Negative Loglikelihood : 5.418938636779785
Did I register and use the mean parameter correctly? Can autograd compute the gradient for torch.distributions.Normal.log_prob, or should I implement backward() for the model myself?
You're overcomplicating the registration of your parameter. You can just assign an nn.Parameter to a new self.mean attribute and then use it like a tensor for the most part.
nn.Module overrides the __setattr__ method, which is called every time you assign an attribute on the module. One of the things it does is check whether the assigned value is an nn.Parameter, and if so, it adds it to the module's dictionary of registered parameters.
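A quick way to see this behavior is the small check below. It is only a sketch: the `Demo` class is a hypothetical module made up for illustration. It also shows why building the distribution from `state_dict()['mean']` gives no gradient: by default `state_dict()` returns detached tensors, which are cut off from the autograd graph.

```python
import torch
import torch.nn as nn

class Demo(nn.Module):  # hypothetical module, for illustration only
    def __init__(self):
        super().__init__()
        # Plain attribute assignment; nn.Module.__setattr__ registers it.
        self.mean = nn.Parameter(torch.zeros(1))

demo = Demo()

# The parameter shows up in the registered parameters under its attribute name.
names = [name for name, _ in demo.named_parameters()]
print(names)

# The attribute itself is a leaf tensor that tracks gradients.
print(demo.mean.requires_grad)

# The state_dict() entry is detached by default, so it carries no gradient history.
print(demo.state_dict()['mean'].requires_grad)
```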
Because of this, the easiest way to register your parameter is as follows:
import torch
import torch.nn as nn

class GaussianModel(nn.Module):
    def __init__(self):
        super(GaussianModel, self).__init__()
        self.mean = nn.Parameter(torch.zeros(1))
        self.pdf = torch.distributions.Normal(self.mean, torch.tensor([1.0]))

    def forward(self, x):
        return -self.pdf.log_prob(x)
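
To check that the gradient now reaches mean, here is a minimal training sketch. It differs from the code above in one respect, which is my own choice rather than part of the answer: the Normal distribution is built inside forward() on every call, so each backward pass works on a freshly built graph that uses the current value of self.mean. The learning rate and the input value 3.0 are taken from the question.

```python
import torch
import torch.nn as nn

class GaussianModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.mean = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        # Assumption: rebuild the distribution on each call so it always
        # uses the live parameter and a fresh autograd graph.
        pdf = torch.distributions.Normal(self.mean, torch.tensor([1.0]))
        return -pdf.log_prob(x)

model = GaussianModel()
optimizer = torch.optim.SGD(model.parameters(), lr=0.002)
x = torch.tensor([3.0])  # the input does not need requires_grad

losses = []
for _ in range(5):
    optimizer.zero_grad()
    nll = model(x)
    nll.backward()
    optimizer.step()
    losses.append(nll.item())
    print('mean:', model.mean.item(), '- negative log-likelihood:', nll.item())
```

With this setup the mean should move a little toward 3.0 on every step and the negative log-likelihood should decrease slightly, instead of staying constant as in the question.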