I'm trying to train a 'trainable Bernoulli distribution' using 'pyro'.
I want to train the Bernoulli distribution's parameter (the probability of success) using an NLL loss.
train_data is a one-hot encoded sparse matrix of shape (2034, 19475), and train_labels has 4 values (4 classes, [0, 1, 2, 3]).
import torch
import pyro
pyd = pyro.distributions
print("torch version:", torch.__version__)
print("pyro version:", pyro.__version__)
import matplotlib.pyplot as plt
import numpy as np
torch.manual_seed(123)
### 0. define Negative Log Likelihood(NLL) loss function
def nll(x_train, distribution):
    return -torch.mean(distribution.log_prob(torch.tensor(x_train, dtype=torch.float)))
### 1. initialize bernoulli distribution(trainable distribution)
train_vars = (pyd.Uniform(low=torch.FloatTensor([0.01]),
                          high=torch.FloatTensor([0.1])).rsample([train_data.shape[-1]]).squeeze())
distribution = pyd.Bernoulli(probs=train_vars)
### 2. initialize 'label 0' data
class_mask = (train_labels==0)
class_data = train_data[class_mask, :]
### 3. initialize optimizer
optim = torch.optim.Adam([train_vars])
train_vars.requires_grad=True
### 4. train loop
for i in range(0,100):
    loss = nll(class_data, distribution)
    loss.backward()
When I run this code, I get a RuntimeError like the one below.
How should I deal with this error?
Your comments would be very much appreciated.
torch version: 1.9.0+cu102
pyro version: 1.7.0
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-269-0081bb1bb843> in <module>
25 loss = nll(class_data, distribution)
26
---> 27 loss.backward()
28
/nf/yes/lib/python3.8/site-packages/torch/_tensor.py in backward(self, gradient, retain_graph, create_graph, inputs)
253 create_graph=create_graph,
254 inputs=inputs)
--> 255 torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
256
257 def register_hook(self, hook):
/nf/yes/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables, inputs)
145 retain_graph = create_graph
146
--> 147 Variable._execution_engine.run_backward(
148 tensors, grad_tensors_, retain_graph, create_graph, inputs,
149 allow_unreachable=True, accumulate_grad=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed). Saved intermediate values of the graph are freed when you call .backward() or autograd.grad(). Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
You need to move
distribution = pyd.Bernoulli(probs=train_vars)
inside the loop, because it is built from train_vars, which requires grad. After the first loss.backward() the saved intermediate values of that graph are freed, so reusing the same distribution object on the next iteration tries to backpropagate through the already-freed graph and raises exactly this error; rebuilding the distribution every iteration gives each backward pass a fresh graph.
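For reference, a minimal sketch of the corrected loop, assuming train_data and train_labels already exist as described (indexable and convertible to a float tensor) and reusing the nll helper from the question. The optim.zero_grad()/optim.step() calls and the final clamp are additions beyond the fix above: the first two are needed for the parameter to actually update, and the clamp just keeps the probabilities in a valid range for Bernoulli.

train_vars = (pyd.Uniform(low=torch.FloatTensor([0.01]),
                          high=torch.FloatTensor([0.1])).rsample([train_data.shape[-1]]).squeeze())
train_vars.requires_grad_(True)   # mark trainable before handing it to the optimizer

class_mask = (train_labels == 0)
class_data = train_data[class_mask, :]

optim = torch.optim.Adam([train_vars])

for i in range(100):
    optim.zero_grad()
    # Rebuild the distribution every step so each backward() runs on a fresh graph.
    distribution = pyd.Bernoulli(probs=train_vars)
    loss = nll(class_data, distribution)
    loss.backward()
    optim.step()
    with torch.no_grad():
        # Keep the per-feature probabilities strictly inside (0, 1) after the gradient step.
        train_vars.clamp_(1e-6, 1 - 1e-6)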