I'm attempting to implement a custom single-forward-pass training algorithm in PyTorch. Since I don't require backpropagation, I am updating the weights of the neural network manually. However, I can't get this to work correctly: after the first pass I repeatedly get an error saying I'm trying to backward through the computational graph a second time, even though I zero out the gradients in the model. I'm not sure where I'm going wrong.
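To show what I mean by a manual update, here is a stripped-down sketch of the pattern I'm following (not my actual model, just the shape of it): one forward pass, read the gradients, then update the parameters by hand under torch.no_grad().

import torch

model = torch.nn.Linear(1, 1)          # stand-in for my network
x = torch.ones(1, requires_grad=True)

out = model(x)
out.backward()                         # populate .grad on x and the parameters

with torch.no_grad():                  # manual parameter update, no optimizer
    for p in model.parameters():
        p.sub_(0.05 * p.grad)
model.zero_grad()

My full class is below.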
import numpy as np
import torch

class OneD_NN_LQR:
    # Initially we will ignore batch size
    def __init__(self, hidden_units, learning_rate_param_C=0.05, batch_size=100):
        # Single layer neural network for the control, of the form f(x) = sum(c_i * g(w_i * x + b_i))
        # We will use a sigmoid activation function (i.e. g = sigmoid)
        self.C = learning_rate_param_C
        self.N = batch_size
        self.hidden_units = hidden_units
        self.dim = 1
        self.layer1 = torch.nn.Linear(in_features=self.dim, out_features=self.hidden_units)
        self.activation = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(in_features=self.hidden_units, out_features=self.dim, bias=False)
        self.model = torch.nn.Sequential(
            self.layer1,
            self.activation,
            self.layer2
        )
        self.w = self.layer1.weight
        self.b = self.layer1.bias
        self.c = self.layer2.weight
        self.Xtilde_w = torch.zeros((self.hidden_units,)).unsqueeze(1)
        self.Xtilde_c = torch.zeros((self.hidden_units,)).unsqueeze(1)
        self.Xtilde_b = torch.zeros((self.hidden_units,))
        self.X = torch.ones((self.dim,), requires_grad=True)
        self.f_x = self.forward(self.X)
        # self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
        self.f_x.backward()
        # self.grad_x = self.grads[0]
        # self.grad_c = self.grads[1].T
        # self.grad_b = self.grads[2]
        # self.grad_w = self.grads[3]
        self.grad_x = self.X.grad
        self.grad_c = self.c.grad.T
        self.grad_b = self.b.grad
        self.grad_w = self.w.grad
        self.time = 0

    def step(self, delta):
        # Stepping also involves updating the values of f(x) and f'(x)
        self.time += delta
        self.step_X(delta)
        self.step_Xtilde(delta)
        self.step_theta(delta)
        self.model.zero_grad()
        self.f_x = self.model.forward(self.X)
        print(self.f_x)
        # self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
        self.f_x.backward()
        # self.grad_x = self.grads[0]
        # self.grad_c = self.grads[1].T
        # self.grad_b = self.grads[2]
        # self.grad_w = self.grads[3]
        self.grad_x = self.X.grad
        self.grad_c = self.c.grad.T
        self.grad_b = self.b.grad
        self.grad_w = self.w.grad
        return self.w, self.c, self.b

    def step_theta(self, delta):
        next_dw, next_dc, next_db = self.next_dtheta(delta)
        with torch.no_grad():
            self.layer1.weight.sub_(next_dw)
            self.layer1.bias.sub_(next_db)
            self.layer2.weight.sub_(next_dc.T)
        self.model.zero_grad()

    def step_X(self, delta):
        next_dX = self.next_dX(delta)
        self.X = self.X + next_dX

    def step_Xtilde(self, delta):
        next_dXtilde_w, next_dXtilde_c, next_dXtilde_b = self.next_dXtilde(delta)
        self.Xtilde_w = self.Xtilde_w + next_dXtilde_w
        self.Xtilde_c = self.Xtilde_c + next_dXtilde_c
        self.Xtilde_b = self.Xtilde_b + next_dXtilde_b

    def next_dtheta(self, delta):
        alpha = self.get_learning_rate(self.C, self.time)
        dw = alpha * (2 * self.X * self.Xtilde_w + 2 * self.f_x * (self.grad_w + self.grad_x * self.Xtilde_w)) * delta
        db = alpha * (2 * self.X * self.Xtilde_b + 2 * self.f_x * (self.grad_b + self.grad_x * self.Xtilde_b)) * delta
        dc = alpha * (2 * self.X * self.Xtilde_w + 2 * self.f_x * (self.grad_c + self.grad_x * self.Xtilde_c)) * delta
        return dw, dc, db

    def get_learning_rate(self, c, time):
        if time > 500: return c / 10
        if time > 100: return c / 5
        if time > 50: return c / 2
        return c

    def next_dXtilde(self, delta):
        dXtilde_w = (- self.Xtilde_w + self.grad_w + self.grad_x * self.Xtilde_w) * delta
        dXtilde_b = (- self.Xtilde_b + self.grad_b + self.grad_x * self.Xtilde_b) * delta
        dXtilde_c = (- self.Xtilde_c + self.grad_c + self.grad_x * self.Xtilde_c) * delta
        return dXtilde_w, dXtilde_c, dXtilde_b

    def next_dX(self, delta):
        to_return = (-self.X + self.f_x) * delta + torch.normal(0, 1, size=(self.dim,)) * (delta ** 0.5)
        return to_return

    def forward(self, x):
        # to_return = torch.unsqueeze(torch.sum(self.c * self.activation(self.w * x + self.b), axis=1), 1)
        to_return = self.model.forward(x)
        return to_return
My training loop is as follows:
x = torch.tensor([5.0]).unsqueeze(1)
y = []
step_size = 1e-2
theta_vals = []
range_end = 10
fwd_propagator = OneD_NN_LQR(16, learning_rate_param_C=100, batch_size=10)
for i in np.arange(0, range_end, step_size):
    theta = fwd_propagator.step(step_size)
    theta_vals.append(theta)
    y.append(fwd_propagator.forward(x)[0])
The problem is with self.f_x = self.forward(self.X) in the step function. It seems that self.X is being treated by torch as part of the computational graph (because it requires grad), and that graph is freed by torch after the first .backward() in __init__, so using it in a second backward pass raises the error. You can change the forward pass in step to:
...
x = self.X.detach().clone().requires_grad_()
self.f_x = self.forward(x)
print(self.f_x)
# self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
self.f_x.backward()
# self.grad_x = self.grads[0]
# self.grad_c = self.grads[1].T
# self.grad_b = self.grads[2]
# self.grad_w = self.grads[3]
self.grad_x = x.grad
...
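For reference, here is a minimal, self-contained sketch (independent of your class, with made-up names) that reproduces the same error and shows how the detach-and-clone pattern avoids it. The trouble starts once the new input is built from the output of a graph that the first .backward() has already freed, which is exactly what happens to self.X after step_X adds next_dX (which depends on self.f_x) to it.

import torch

w = torch.nn.Parameter(torch.tensor(2.0))
x = torch.ones(1, requires_grad=True)

y = w * x
y.backward()            # frees y's graph

x = x + y               # the new x still references the freed graph through y
z = w * x
# z.backward()          # RuntimeError: Trying to backward through the graph a second time

# Detaching before the next forward pass cuts the link to the freed graph:
x_new = x.detach().clone().requires_grad_()
z = w * x_new
z.backward()            # works; x_new.grad and w.grad are populated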