How to update a leaf variable in PyTorch


I am trying to implement simple gradient descent to find the minimum of a quadratic function using PyTorch. I'm doing this to get a better sense of how autograd works, but it's not going very well. Let's say that I want to find the minimum of

y = 3x^2 + 4x + 9

as an arbitrary example. Below is my first attempt to run one step of gradient descent and then recompute the gradient:

import torch

# Step size
alpha = 0.1

# Random starting point
x = torch.tensor([42.0], requires_grad=True)

# Function y = 3 * x^2 + 4 * x + 9
# Find the minimum of this with gradient descent
y = 3 * x ** 2 + 4 * x + 9
y.backward()
print(x.grad)

with torch.no_grad():
    x -= alpha * x.grad

y.backward()
print(x.grad)

This raised an error on the second .backward() call, so I updated it to this:

import torch

# Step size
alpha = 0.1

# Random starting point
x = torch.tensor([42.0], requires_grad=True)

# Function y = 3 * x^2 + 4 * x + 9
# Find the minimum of this with gradient descent
y = 3 * x ** 2 + 4 * x + 9
y.backward(retain_graph=True) # <--- Change here
print(x.grad)

with torch.no_grad():
    x -= alpha * x.grad

y.backward(retain_graph=True) # <--- And here
print(x.grad)

and I get the error:

RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation: [torch.FloatTensor [1]] is at version 1; expected version 0 instead. Hint: enable anomaly detection to find the operation that failed to compute its gradient, with torch.autograd.set_detect_anomaly(True).

I get the sense that I am fundamentally misunderstanding something about PyTorch here. How would I do this correctly?


Solution

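  • The error comes from reusing a graph after modifying one of its inputs in place. When y = 3 * x ** 2 + 4 * x + 9 is evaluated, autograd saves x (it is needed to differentiate x ** 2) and records x's version counter. The in-place update x -= ... increments that counter even inside torch.no_grad(), so the second .backward() through the old graph finds x "at version 1; expected version 0" and stops. The fix is to recompute y from the updated x on every step, so that each .backward() runs on a freshly built graph.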

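    1. Gradients in PyTorch accumulate by default: each .backward() adds to x.grad instead of overwriting it (this is what makes accumulating gradients over several mini-batches possible). For manual gradient descent you must therefore reset the gradient with x.grad.zero_() (or set x.grad = None, or re-create x as a fresh leaf with .detach()) before the next .backward() call; otherwise every step uses the sum of all previous gradients.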

    2. Updating x (or any tensor that requires gradients) must happen inside torch.no_grad() so that autograd does not track the update itself. Outside it, an in-place update of a leaf that requires grad raises a RuntimeError immediately, and an out-of-place update such as x = x - alpha * x.grad turns x into a non-leaf tensor whose .grad is no longer populated.

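    3. retain_graph=True is only needed when you call .backward() more than once through the same graph (for example, when two losses share part of a graph and are backpropagated separately). In simple gradient descent each step builds a new graph and calls .backward() on it exactly once, so retain_graph=True is unnecessary. It would not have fixed this error anyway: keeping the graph's buffers alive does not stop autograd from detecting that x was modified.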

    Here is the corrected code:

    import torch
    
    alpha = 0.1
    x = torch.tensor([42.0], requires_grad=True)
    
    for i in range(10):
        y = 3 * x ** 2 + 4 * x + 9
        y.backward()
        print(f"Step {i+1}: x = {x.item():.4f}, y = {y.item():.4f}, grad = {x.grad.item():.4f}")
        with torch.no_grad():
            x -= alpha * x.grad
        x.grad.zero_()
    # Instead of x.grad.zero_(), you can set x.grad = None, or rebind x = x.detach().clone().requires_grad_(True) to start the next step from a fresh leaf tensor.
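The points above can be checked in isolation with a short script. This is a sketch assuming a reasonably recent PyTorch version. f is a hypothetical helper wrapping the function from the question, and x._version is an internal (underscore) attribute that exposes the version counter named in the error message; it is read here only for illustration.

```python
import torch

alpha = 0.1  # step size, as in the question


def f(x):
    # Hypothetical helper: y = 3x^2 + 4x + 9, the function from the question
    return 3 * x ** 2 + 4 * x + 9


# 1) Reusing a graph after an in-place update fails, even under no_grad().
x = torch.tensor([42.0], requires_grad=True)
y = f(x)                        # the graph saves x for the x ** 2 backward
y.backward(retain_graph=True)
version_before = x._version     # internal version counter of x
with torch.no_grad():
    x -= alpha * x.grad         # in-place update bumps the version counter
version_after = x._version
try:
    y.backward()                # old graph still expects the old version of x
    reused_old_graph = True
except RuntimeError:
    reused_old_graph = False

# 2) Without zeroing, gradients add up across backward() calls.
x = torch.tensor([1.0], requires_grad=True)
f(x).backward()                 # fresh graph each call, no retain_graph needed
f(x).backward()
accumulated_grad = x.grad.item()  # two copies of dy/dx at x = 1

# 3) The corrected loop: a new graph and a reset gradient on every step.
x = torch.tensor([42.0], requires_grad=True)
for _ in range(50):
    y = f(x)
    y.backward()
    with torch.no_grad():
        x -= alpha * x.grad
    x.grad.zero_()

print(version_before, version_after, reused_old_graph, accumulated_grad, x.item())
```

Part 1 reproduces the error from the question, part 2 leaves 2 * (6 * 1 + 4) = 20 in x.grad because the two gradients are summed, and part 3 approaches the analytic minimum x = -2/3, where dy/dx = 6x + 4 = 0. With alpha = 0.1 each step maps x to x - 0.1 * (6x + 4) = 0.4x - 0.4, so the distance to -2/3 shrinks by a factor of 0.4 per step.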