python, pytorch, torch, autograd

Understanding and introspecting torch.autograd.backward


In order to locate a bug, I am trying to introspect the backward calculation in PyTorch. Following the description of torch's Autograd mechanics, I added backward hooks to each parameter of my model as well as hooks on the grad_fn of each activation. The following code snippet illustrates how I add the hooks to the grad_fn:

import torch
import torch.distributed as dist


def make_hook(grad_fn, note=None):
    # Build a post-hook that reports which node ran and how many arguments it received.
    if grad_fn is not None and grad_fn.name() is not None:
        def hook(*args, **kwargs):
            print(f"[{dist.get_rank()}] {grad_fn.name()} with {len(args)} args "
                  f"and {len(kwargs)} kwargs [{note or '/'}]")
        return hook
    else:
        return None


def register_hooks_on_grads(grad_fn, make_hook_fn):
    # Recursively walk the autograd graph and register hooks, skipping leaf nodes.
    if not grad_fn:
        return
    hook = make_hook_fn(grad_fn)
    if hook:
        grad_fn.register_hook(hook)
    for fn, _ in grad_fn.next_functions:
        if not fn:
            continue
        # Leaf (AccumulateGrad) nodes expose a .variable attribute; don't recurse into them.
        var = getattr(fn, "variable", None)
        if var is None:
            register_hooks_on_grads(fn, make_hook_fn)


x = torch.zeros(15, requires_grad=True)
y = x.exp()
z = y.sum()
register_hooks_on_grads(z.grad_fn, make_hook)

When running my model, I noticed that each invocation of the hook receives two positional arguments and no keyword arguments. In the case of an AddBackward function, the first argument is a list of two tensors and the second argument is a list of one tensor. The same holds true for the LinearWithGradAccumulationAndAsyncCommunicationBackward function. In the case of a MeanBackward function, both arguments are lists with one tensor each.
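
To see what those arguments actually contain, a variant of make_hook that prints tensor shapes instead of argument counts can be used with the register_hooks_on_grads function above; the helper name below is just for illustration:

import torch


def make_shape_hook(grad_fn, note=None):
    # Illustrative variant of make_hook: print the shapes held by each positional argument.
    if grad_fn is None:
        return None

    def hook(*args, **kwargs):
        shapes = [[tuple(t.shape) for t in arg if t is not None] for arg in args]
        print(f"{grad_fn.name()}: {shapes} [{note or '/'}]")

    return hook


x = torch.zeros(15, requires_grad=True)
y = x.exp()
z = y.sum()
register_hooks_on_grads(z.grad_fn, make_shape_hook)
z.backward()  # triggers the registered post-hooks for the sum and exp nodes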

My conjecture is that the first argument probably contains the inputs to the operator (or whatever was saved with ctx.save_for_backward) and that the second argument contains the gradients. Am I right about this? Can I just replicate the backward computation with grad_fn(*args), or is there more to it (e.g., state)?

Unfortunately, I didn't find any documentation on this. I am grateful for any pointer towards the relevant documentation.


Solution

  • After revisiting the above-mentioned documentation, I noticed that registering hooks on nodes refers to grad_fn.register_hook, and that there are two different hooks for nodes: one executed before the node runs and one executed after it. In my code above, I only registered hooks that run after the node is executed, so when my training code hit an error in a backward operator, I could only see the last successfully executed operator, not the one currently running. After I also registered a pre-hook on each node, it worked:

    def register_hooks_on_grads(grad_fn, make_hook_fn):
        # make_hook_fn now returns a (prehook, posthook) pair for each node.
        if not grad_fn:
            return
        prehook, posthook = make_hook_fn(grad_fn)
        if prehook:
            grad_fn.register_prehook(prehook)
        if posthook:
            grad_fn.register_hook(posthook)
        for fn, _ in grad_fn.next_functions:
            if not fn:
                continue
            # Leaf (AccumulateGrad) nodes expose a .variable attribute; skip them.
            var = getattr(fn, "variable", None)
            if var is None:
                register_hooks_on_grads(fn, make_hook_fn)
    

    The pre-hook is executed before the backward function runs and receives the incoming gradients, i.e. the gradients produced by the preceding backward function (the gradients with respect to the node's outputs). The post-hook is executed after the backward function and additionally receives the outputs of grad_fn, i.e. the gradients it computed with respect to its inputs.
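
    Since the original make_hook only builds a single post-hook, the updated register_hooks_on_grads needs a companion that returns a pair. A minimal sketch of such a variant (the name make_hooks and the printed format are just illustrative), assuming the same torch.distributed setup as in the question:

    import torch.distributed as dist


    def make_hooks(grad_fn, note=None):
        # Return a (prehook, posthook) pair for the given autograd node.
        if grad_fn is None:
            return None, None

        def prehook(grad_outputs):
            # Runs before grad_fn executes; grad_outputs are the incoming gradients.
            print(f"[{dist.get_rank()}] about to run {grad_fn.name()} "
                  f"with {len(grad_outputs)} incoming grads [{note or '/'}]")

        def posthook(grad_inputs, grad_outputs):
            # Runs after grad_fn executed; grad_inputs are the gradients it produced.
            print(f"[{dist.get_rank()}] finished {grad_fn.name()} "
                  f"producing {len(grad_inputs)} grads [{note or '/'}]")

        return prehook, posthook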

    In fact, I can use grad_fn(*args, **kwargs) to replicate the backward computation, where args and kwargs are the inputs to the pre-hook function.
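
    For the toy graph from the question, replaying individual nodes by hand looks roughly like this (a sketch that assumes grad_fn nodes are callable with the incoming gradients; whether a node returns a single tensor or a tuple may differ between operators):

    import torch

    x = torch.zeros(15, requires_grad=True)
    y = x.exp()
    z = y.sum()

    def as_tuple(out):
        # Normalize the node output, since a node may return a tensor or a tuple.
        return out if isinstance(out, tuple) else (out,)

    # Incoming gradient for the final sum node: dz/dz = 1.
    (grad_y,) = as_tuple(z.grad_fn(torch.tensor(1.0)))
    # Feed it into the exp node to get the gradient with respect to x.
    (grad_x,) = as_tuple(y.grad_fn(grad_y))

    # d(exp(x))/dx = exp(x), so the replayed gradient should match y itself.
    print(torch.allclose(grad_x, y.detach()))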