This is probably a stupid question, but I was going through the micrograd repo and came across a nested function that confuses me.
class Value:
    """ stores a single scalar value and its gradient """

    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self.grad = 0
        # internal variables used for autograd graph construction
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op # the op that produced this node, for graphviz / debugging / etc

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            self.grad += out.grad
            other.grad += out.grad
        out._backward = _backward

        return out
The `_backward` function defined inside `__add__` is attached to a newly created `Value` object, and `_backward` uses `self` and `other`. It is clear from the definition where it gets those from, but when we assign the object's `_backward` to refer to the nested `_backward`, how is it able to get the `self` and `other` objects later?
The `_backward` function can get the `self` and `other` objects because it is a closure. A closure is a function that can access variables from the scope in which it was created, even after that scope has ended. Here, `_backward` is created inside the `__add__` method, so it keeps access to `self` and `other` (and to `out`, which it also reads) even after `__add__` has returned.
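To see this in action, here is a trimmed-down copy of the `Value` class from the question (only `__init__` and `__add__` are kept, for brevity) followed by a call to `_backward` made after `__add__` has already returned. The values `2.0` and `3.0` and the manual `c.grad = 1.0` are just illustrative inputs:

```python
class Value:
    """Trimmed-down copy of micrograd's Value: only __init__ and __add__."""

    def __init__(self, data, _children=(), _op=''):
        self.data = data
        self.grad = 0
        self._backward = lambda: None
        self._prev = set(_children)
        self._op = _op

    def __add__(self, other):
        other = other if isinstance(other, Value) else Value(other)
        out = Value(self.data + other.data, (self, other), '+')

        def _backward():
            # self, other and out are free variables resolved through the closure
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out


a = Value(2.0)
b = Value(3.0)
c = a + b          # __add__ has finished running at this point
c.grad = 1.0       # pretend dL/dc = 1
c._backward()      # still updates a.grad and b.grad through the closure
print(a.grad, b.grad)  # 1.0 1.0
```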
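Python does not copy the stack frame to make this work. When the compiler sees that a nested function refers to a local variable of the enclosing function (here `self`, `other` and `out`), it stores that variable in a cell object instead of an ordinary local slot. The function object created by the nested `def` holds references to those cells in its `__closure__` attribute, and the names they belong to are listed in `_backward.__code__.co_freevars`.

When `__add__` returns, its frame is discarded, but the cells are not: the `_backward` function attached to `out` still references them, and the cells in turn keep the `self`, `other` and `out` objects alive. Note that the closure captures the variables themselves, not a snapshot of their values, so if the enclosing function rebinds a captured variable after the nested function is defined, the nested function sees the new value.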
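You can inspect the captured variables directly: a function's `__code__.co_freevars` lists the names it reads from enclosing scopes, and its `__closure__` tuple holds one cell per name, in the same order. The sketch below uses a minimal stand-in for `Value` (an assumption: only the parts `__add__` needs are kept) and checks that each cell holds the very same object, not a copy:

```python
class Value:
    # Minimal stand-in for micrograd's Value (only what __add__ needs)
    def __init__(self, data, _children=()):
        self.data = data
        self.grad = 0
        self._backward = lambda: None
        self._prev = set(_children)

    def __add__(self, other):
        out = Value(self.data + other.data, (self, other))

        def _backward():
            self.grad += out.grad
            other.grad += out.grad

        out._backward = _backward
        return out


a, b = Value(2.0), Value(3.0)
c = a + b

fn = c._backward
# Names of the enclosing-scope variables that _backward refers to
print(fn.__code__.co_freevars)

# Pair each free-variable name with its cell; cell_contents is the live object
cells = dict(zip(fn.__code__.co_freevars, fn.__closure__))
print(cells["self"].cell_contents is a)   # True
print(cells["other"].cell_contents is b)  # True
print(cells["out"].cell_contents is c)    # True
```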
This is why `_backward` can use the `self` and `other` objects even though it is defined inside `__add__` and only called later. Here is a simpler example that illustrates how closures work:
def outer():
    x = 10
    def inner():
        print(x)
    return inner

inner = outer()
inner()  # prints 10, even though outer() has already returned
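A variation on the same example shows that a closure refers to the variable itself rather than to a copy of its value. In the first part, `x` is rebound after `inner` is defined; in the second, two closures share one variable through `nonlocal` (Python 3). The function names `make_counter`, `increment` and `current` are purely illustrative:

```python
def outer():
    x = 10

    def inner():
        # x is looked up in the shared cell every time inner() runs
        return x

    x = 20  # rebind x after inner has been defined
    return inner


inner = outer()
print(inner())  # 20: the closure sees the variable's latest value, not a snapshot


def make_counter():
    count = 0

    def increment():
        nonlocal count  # rebind the enclosing variable instead of creating a local
        count += 1
        return count

    def current():
        return count

    return increment, current


inc, cur = make_counter()
inc()
inc()
print(cur())  # 2: increment and current share the same cell for count
```

If each closure had its own copy of the enclosing variables, `cur()` would still return 0 here.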