In order to locate a bug, I am trying to introspect the backward calculation in PyTorch. Following the description of torch's Autograd mechanics, I added backward hooks to each parameter of my model as well as hooks on the grad_fn of each activation. The following code snippet illustrates how I add the hooks to the grad_fn:
import torch
import torch.distributed as dist

def make_hook(grad_fn, note=None):
    if grad_fn is not None and grad_fn.name() is not None:
        def hook(*args, **kwargs):
            # The rank prefix assumes torch.distributed has been initialized.
            print(f"[{dist.get_rank()}] {grad_fn.name()} with {len(args)} args "
                  f"and {len(kwargs)} kwargs [{note or '/'}]")
        return hook
    else:
        return None

def register_hooks_on_grads(grad_fn, make_hook_fn):
    if not grad_fn:
        return
    hook = make_hook_fn(grad_fn)
    if hook:
        grad_fn.register_hook(hook)
    for fn, _ in grad_fn.next_functions:
        if not fn:
            continue
        # Leaf tensors are reached via AccumulateGrad nodes, which carry a
        # .variable attribute; only recurse into non-leaf nodes.
        var = getattr(fn, "variable", None)
        if var is None:
            register_hooks_on_grads(fn, make_hook_fn)

x = torch.zeros(15, requires_grad=True)
y = x.exp()
z = y.sum()
register_hooks_on_grads(z.grad_fn, make_hook)
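With the hooks in place, running the backward pass on this toy graph fires one hook per node (a minimal usage sketch; for a quick single-process test one could replace dist.get_rank() with a constant):

# Trigger the backward pass; each registered post-hook prints once as its node runs.
z.backward()
print(x.grad)   # d(sum(exp(x)))/dx = exp(x), i.e. all ones for x = 0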
When running my model, I noticed that each invocation of hook gets two positional arguments and no keyword arguments. In the case of an AddBackward function, the first argument is a list of two tensors and the second argument is a list of one tensor. The same holds true for the LinearWithGradAccumulationAndAsyncCommunicationBackward function. In the case of a MeanBackward function, both arguments are lists with one tensor each.
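To look at these arguments more closely, the hook body can print the shape of every tensor it receives instead of just counting arguments (a sketch only; make_shape_hook is a hypothetical name, and it is registered exactly like make_hook above):

def make_shape_hook(grad_fn, note=None):
    if grad_fn is None:
        return None
    def hook(*args, **kwargs):
        # Each positional argument is a sequence of tensors (or None entries);
        # print their shapes to see what the node received and produced.
        for i, arg in enumerate(args):
            shapes = [tuple(t.shape) if t is not None else None for t in arg]
            print(f"[{grad_fn.name()}] arg {i}: {shapes}")
    return hook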
My conjecture is that the first argument probably contains the inputs to the operator (or whatever was saved with ctx.save_for_backward) and that the second argument contains the gradients. Am I right about this? Can I just replicate the backward computation with grad_fn(*args), or is there more to it (e.g., state)?
Unfortunately, I didn't find any documentation on this. I am grateful for any pointer towards the relevant documentation.
After revisiting the above-mentioned documentation, I noticed that registering hooks on nodes refers to grad_fn.register_hook, and that there are two different hooks for nodes: one executed before the node runs and one executed after. In my code above, I only registered hooks that run after the node is executed, so when I ran my training code and a backward operator reported an error, I couldn't see the currently running operator, only the last successful one. After I registered a pre-hook on the node, it worked:
def register_hooks_on_grads(grad_fn, make_hook_fn):
    if not grad_fn:
        return
    prehook, posthook = make_hook_fn(grad_fn)
    if prehook:
        grad_fn.register_prehook(prehook)
    if posthook:
        grad_fn.register_hook(posthook)
    for fn, _ in grad_fn.next_functions:
        if not fn:
            continue
        var = getattr(fn, "variable", None)
        if var is None:
            register_hooks_on_grads(fn, make_hook_fn)
The pre-hook runs before the backward function and gets as input the current gradients coming from the preceding backward function. The post-hook runs after the backward function and additionally gets the output of grad_fn.
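Since the updated register_hooks_on_grads now expects make_hook_fn to return a (prehook, posthook) pair, the hook factory has to be adapted accordingly. A minimal sketch of such a factory (make_pre_post_hooks is a hypothetical name; the rank prefix is dropped here so the snippet runs without torch.distributed):

def make_pre_post_hooks(grad_fn, note=None):
    if grad_fn is None:
        return None, None

    def prehook(grad_outputs):
        # Runs before grad_fn; grad_outputs holds the gradients arriving
        # from the preceding backward node.
        print(f"PRE  {grad_fn.name()} with {len(grad_outputs)} incoming grads [{note or '/'}]")

    def posthook(grad_inputs, grad_outputs):
        # Runs after grad_fn; grad_inputs holds the gradients grad_fn produced.
        print(f"POST {grad_fn.name()} produced {len(grad_inputs)} grads [{note or '/'}]")

    return prehook, posthook

x = torch.zeros(15, requires_grad=True)
z = x.exp().sum()
register_hooks_on_grads(z.grad_fn, make_pre_post_hooks)
z.backward()   # prints a PRE line before and a POST line after each node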
In fact, I can use grad_fn(*args, **kwargs) to replicate the backward computation, where args and kwargs are the inputs to the pre-hook function.
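To check this on the toy graph from the beginning, one can capture the pre-hook inputs and call the node again after the backward pass (a sketch; the direct indexing into next_functions and retain_graph=True are specific to this standalone snippet):

captured = {}

def capture_prehook(grad_outputs):
    # Stash the gradients that flow into ExpBackward0 during backward().
    captured["grad_outputs"] = grad_outputs

x = torch.zeros(15, requires_grad=True)
z = x.exp().sum()
exp_node = z.grad_fn.next_functions[0][0]   # ExpBackward0
exp_node.register_prehook(capture_prehook)
z.backward(retain_graph=True)               # keep saved tensors alive for the replay

# Replay the node outside the backward pass with the captured inputs.
replayed = exp_node(*captured["grad_outputs"])
replayed = replayed[0] if isinstance(replayed, tuple) else replayed
print(torch.allclose(replayed, x.grad))     # True: both equal exp(x) * 1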