I am trying to understand how PyTorch keeps track of overwritten tensor values for autodifferentiation.
In the code below, w is overwritten before autodifferentiation, yet PyTorch remembers the original value of w when executing z.backward().
import torch
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
y = w * x
w = torch.tensor([4.0], requires_grad=True)
z = y ** 2
z.backward()
print(x.grad)  # prints tensor([8.])
I expected that x.grad would evaluate to 2*w**2*x = 2*16*1 = 32, given that w = 4 at the time of calling z.backward(). However, the result is x.grad = 2*4*1 = 8: PyTorch "remembers" that w was equal to 2 at the time when y was defined.
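For reference, the arithmetic behind both numbers is just the chain rule applied to z = y**2 with y = w*x:

dz/dx = dz/dy * dy/dx = 2*y * w = 2*(w*x)*w = 2*w**2*x

so with the original w = 2 this gives 2*4*1 = 8, while w = 4 would give 2*16*1 = 32.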
How does PyTorch do this?
Assigning a new value to w only rebinds the name w to a different tensor object; it does not modify the original tensor. That original tensor does not disappear, because autograd still holds a reference to it in the computation graph built for y. You can learn more about variable assignment here: https://realpython.com/python-variables/#variable-assignment
An object’s life begins when it is created, at which time at least one reference to it is created. During an object’s lifetime, additional references to it may be created, as you saw above, and references to it may be deleted as well. An object stays alive, as it were, so long as there is at least one reference to it. When the number of references to an object drops to zero, it is no longer accessible. At that point, its lifetime is over. Python will eventually notice that it is inaccessible and reclaim the allocated memory so it can be used for something else. In computer lingo, this process is referred to as garbage collection.
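As a minimal sketch of this, you can check that reassigning w creates a brand-new object, while the multiplication node recorded for y (exposed via the standard grad_fn attribute) is what keeps the old tensor alive:

import torch

x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
y = w * x                 # the MulBackward0 node saves the operands it needs

old_id = id(w)            # identity of the tensor the graph captured
w = torch.tensor([4.0], requires_grad=True)   # rebinds the *name* w only

print(id(w) == old_id)    # False: the name w now points at a new tensor
print(y.grad_fn)          # <MulBackward0 ...>: the node still referencing the old w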
In contrast, if you modify the tensor itself with an in-place operation, an error is thrown right away (if w itself did not require gradients, the in-place write would succeed, but an error would then be thrown during backward, because autograd notices that a tensor saved for the backward pass was modified):
import torch
x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0], requires_grad=True)
y = w * x
w.fill_(4.0)  # in-place write to a leaf tensor that requires grad
RuntimeError: a leaf Variable that requires grad is being used in an in-place operation.
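To illustrate the parenthetical above, here is a minimal sketch of the second case; w is created without requires_grad, so the in-place write itself succeeds and the failure only surfaces during backward (the exact error text can vary between PyTorch versions):

import torch

x = torch.tensor([1.0], requires_grad=True)
w = torch.tensor([2.0])   # no requires_grad, so fill_ is allowed here
y = w * x                 # autograd saves w to compute the gradient w.r.t. x
w.fill_(4.0)              # succeeds, but bumps w's version counter
z = y ** 2
z.backward()              # raises a RuntimeError along the lines of:
                          # "one of the variables needed for gradient computation
                          #  has been modified by an inplace operation"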