pytorch, autograd

Forward-pass-only training with a custom step


I'm attempting to implement a custom single-forward-pass training algorithm in PyTorch. Since I don't require backpropagation, I am updating the weights of the neural network manually. However, I can't get this to work correctly: after the first pass, I repeatedly get the error that I'm trying to backward through the computational graph a second time, despite having zeroed out the gradients in the model. I'm not sure where I'm going wrong.
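
For reference, the failure looks like the standard case where a tensor left over from an already-backwarded graph is reused in a later forward/backward. A stripped-down sketch (not my actual code, just my guess at the failure mode) that raises the same RuntimeError:

import torch

model = torch.nn.Linear(1, 1)
x = torch.ones(1, requires_grad=True)

f_x = model(x)
f_x.backward()      # frees the saved tensors of this graph

x = x + 0.1 * f_x   # the updated x still hangs onto the freed graph
f_x = model(x)
f_x.backward()      # RuntimeError: Trying to backward through the graph a second time

My actual class: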

import numpy as np
import torch


class OneD_NN_LQR:
    # Initially we will ignore batch size
    def __init__(self, hidden_units, learning_rate_param_C=0.05, batch_size=100):
        # Single layer neural network for the control, of the form f(x) = sum(c_i * g(w_i * x + b_i))
        # We will use a ReLU activation function (i.e. g = ReLU)
        self.C = learning_rate_param_C
        self.N = batch_size
        self.hidden_units = hidden_units
        self.dim = 1

        self.layer1 = torch.nn.Linear(in_features=self.dim, out_features=self.hidden_units)
        self.activation = torch.nn.ReLU()
        self.layer2 = torch.nn.Linear(in_features=self.hidden_units, out_features=self.dim, bias=False)
        self.model = torch.nn.Sequential(
            self.layer1,
            self.activation,
            self.layer2
        )
        self.w = self.layer1.weight
        self.b = self.layer1.bias
        self.c = self.layer2.weight

        self.Xtilde_w = torch.zeros((self.hidden_units,)).unsqueeze(1)
        self.Xtilde_c = torch.zeros((self.hidden_units,)).unsqueeze(1)
        self.Xtilde_b = torch.zeros((self.hidden_units,))

        self.X = torch.ones((self.dim,), requires_grad=True)
        self.f_x = self.forward(self.X)

        #self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
        self.f_x.backward()
        # self.grad_x = self.grads[0]
        # self.grad_c = self.grads[1].T
        # self.grad_b = self.grads[2]
        # self.grad_w = self.grads[3]
        self.grad_x = self.X.grad
        self.grad_c = self.c.grad.T
        self.grad_b = self.b.grad
        self.grad_w = self.w.grad
        
        self.time = 0
        
    def step(self, delta):
        # Stepping also involves updating the values of f(x) and f'(x)
        self.time += delta

        self.step_X(delta)
        self.step_Xtilde(delta)
        self.step_theta(delta)
        
        self.model.zero_grad()

        self.f_x = self.model.forward(self.X)

        print(self.f_x)
        #self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
        self.f_x.backward()
        # self.grad_x = self.grads[0]
        # self.grad_c = self.grads[1].T
        # self.grad_b = self.grads[2]
        # self.grad_w = self.grads[3]
        self.grad_x = self.X.grad
        self.grad_c = self.c.grad.T
        self.grad_b = self.b.grad
        self.grad_w = self.w.grad
        
        return self.w, self.c, self.b

    def step_theta(self, delta):
        next_dw, next_dc, next_db = self.next_dtheta(delta)

        with torch.no_grad():
            self.layer1.weight.sub_(next_dw)
            self.layer1.bias.sub_(next_db)
            self.layer2.weight.sub_(next_dc.T)
            self.model.zero_grad()

    def step_X(self, delta):
        next_dX = self.next_dX(delta)
        self.X = self.X + next_dX

    def step_Xtilde(self, delta):
        next_dXtilde_w, next_dXtilde_c, next_dXtilde_b = self.next_dXtilde(delta)

        self.Xtilde_w = self.Xtilde_w + next_dXtilde_w
        self.Xtilde_c = self.Xtilde_c + next_dXtilde_c
        self.Xtilde_b = self.Xtilde_b + next_dXtilde_b
        

    def next_dtheta(self, delta):
        alpha = self.get_learning_rate(self.C, self.time)

        dw = alpha * (2 * self.X * self.Xtilde_w + 2 * self.f_x * (self.grad_w + self.grad_x * self.Xtilde_w)) * delta
        db = alpha * (2 * self.X * self.Xtilde_b + 2 * self.f_x * (self.grad_b + self.grad_x * self.Xtilde_b)) * delta
        dc = alpha * (2 * self.X * self.Xtilde_c + 2 * self.f_x * (self.grad_c + self.grad_x * self.Xtilde_c)) * delta
        
        return dw, dc, db

    def get_learning_rate(self, c, time):
        if time > 500: return c / 10
        if time > 100: return c / 5
        if time > 50: return c / 2
        return c

    def next_dXtilde(self, delta):
        
        dXtilde_w = (- self.Xtilde_w + self.grad_w + self.grad_x * self.Xtilde_w) * delta
        dXtilde_b = (- self.Xtilde_b + self.grad_b + self.grad_x * self.Xtilde_b) * delta
        dXtilde_c = (- self.Xtilde_c + self.grad_c + self.grad_x * self.Xtilde_c) * delta
        
        return dXtilde_w, dXtilde_c, dXtilde_b

    def next_dX(self, delta):
        to_return = (-self.X + self.f_x) * delta + torch.normal(0, 1, size=(self.dim,)) * (delta ** 0.5)
        return to_return

    def forward(self, x):
        #to_return = torch.unsqueeze(torch.sum(self.c * self.activation(self.w * x + self.b), axis=1), 1)
        to_return = self.model.forward(x)
        return to_return

My training loop is as follows:

x = torch.tensor([5.0]).unsqueeze(1)  # float tensor; the Linear layers expect floating-point input
y = []
step_size = 1e-2
theta_vals = []
range_end = 10
fwd_propagator = OneD_NN_LQR(16, learning_rate_param_C=100, batch_size=10)
for i in np.arange(0, range_end, step_size):
    theta = fwd_propagator.step(step_size)
    theta_vals.append(theta)
    y.append(fwd_propagator.forward(x)[0])

Solution

  • The problem is with the self.f_x = self.forward(self.X) call in the step function. Because self.X requires grad, torch treats it as part of the computational graph, and after the first .backward() in __init__ that graph has been freed, so computing with self.X again raises the error. You can change the forward pass in step to:

    ...
    x = self.X.detach().clone().requires_grad_()
    self.f_x = self.forward(x)
    
    print(self.f_x)
    # self.grads = torch.autograd.grad(self.f_x, inputs=[self.X, self.layer2.weight, self.layer1.bias, self.layer1.weight])
    self.f_x.backward()
    # self.grad_x = self.grads[0]
    # self.grad_c = self.grads[1].T
    # self.grad_b = self.grads[2]
    # self.grad_w = self.grads[3]
    self.grad_x = x.grad
    ...
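
  • Alternatively, since the stale history appears to come from step_X rebuilding self.X out of tensors from the previous step's graph (self.f_x enters through next_dX), you could break the link there instead and leave the rest of step unchanged. A sketch of that variant (not part of the fix above):

    def step_X(self, delta):
        next_dX = self.next_dX(delta)
        # Detach so the updated state no longer references the freed graph,
        # then re-enable grad so self.X.grad is populated by the next backward().
        self.X = (self.X + next_dX).detach().requires_grad_()

    With this variant, self.grad_x = self.X.grad in step keeps working, because self.X is a fresh leaf tensor on every step.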