
How exactly the forward and backward hooks work in PyTorch


I am trying to understand how, code-wise, hooks operate in PyTorch. I have a model and I would like to set a forward and a backward hook after a specific layer; I guess the easiest way is to register the hooks on that specific module. This introductory video warns that the module backward hook contains a bug, but I am not sure whether that is still the case.

My code looks as follows:

def __init__(self, model, attention_layer_name='desired_name_module',discard_ratio=0.9):
  self.model = model
  self.discard_ratio = discard_ratio
  for name, module in self.model.named_modules():
    if attention_layer_name in name:
        module.register_forward_hook(self.get_attention)
        module.register_backward_hook(self.get_attention_gradient)

  self.attentions = []
  self.attention_gradients = []

def get_attention(self, module, input, output):
  self.attentions.append(output.cpu())

def get_attention_gradient(self, module, grad_input, grad_output):
  self.attention_gradients.append(grad_input[0].cpu())

def __call__(self, input_tensor, category_index):
  self.model.zero_grad()
  output = self.model(input_tensor)
  loss = ...
  loss.backward()

I don't understand how, code-wise, the following lines work:

module.register_forward_hook(self.get_attention)
module.register_backward_hook(self.get_attention_gradient)

I register a hook on my desired module, but in each case I pass a function without calling it and without any arguments. Python-wise, how does this work exactly? How are the arguments supplied to the functions passed to register_forward_hook and register_backward_hook when they are eventually called?


Solution

  • How does a hook work?

    A hook lets you execute a specific function, referred to as a "callback", when a particular action is performed. In your case, you expect self.get_attention to be called each time the forward pass of module has run. The key point is that you never call the callback yourself: you only hand over the function object, and whoever performs the action calls it later with the arguments it chooses. Here is a minimal example of a hook mechanism. The simple class below lets you register new callbacks through register_hook; when the instance is called (via __call__), every registered hook is called with the provided arguments:

    class Obj:
        def __init__(self):
            self.hooks = []
        
        def register_hook(self, hook):
            self.hooks.append(hook)
    
        def __call__(self, x, y):
            print('instance called')
            for hook in self.hooks:
                hook(x, y)
    

    First, implement two hooks for demonstration purposes:

    def foo(x, y):
        print(f'foo called with {x} and {y}')
    def bar(x, _):
        print(f'bar called with {x}')
    

    And initialize an instance of Obj:

    obj = Obj()
    

    You can register a hook and call the instance:

    >>> obj.register_hook(foo)
    >>> obj('yes', 'no')
    instance called
    foo called with yes and no
    

    You can register another hook and call the instance again; this time both hooks are triggered:

    >>> obj.register_hook(bar)
    >>> obj('yes', 'no')
    instance called
    foo called with yes and no
    bar called with yes
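    The same idea explains your question: the object that owns the hooks decides which arguments they receive. The sketch below uses a hypothetical MiniModule (a simplified stand-in, not PyTorch's actual implementation) whose __call__ runs forward and then calls each hook as hook(module, inputs, output), the signature nn.Module uses for forward hooks. As with PyTorch forward hooks, a hook that returns something other than None replaces the output:

```python
class MiniModule:
    """Simplified stand-in for nn.Module (hypothetical, not PyTorch's real code)."""

    def __init__(self):
        self._forward_hooks = []

    def register_forward_hook(self, hook):
        # Only the function object is stored; nothing is called yet.
        self._forward_hooks.append(hook)

    def forward(self, x):
        return x * 2

    def __call__(self, *args):
        output = self.forward(*args)
        # The owner chooses the arguments: the module itself,
        # the tuple of positional inputs, and the forward output.
        for hook in self._forward_hooks:
            result = hook(self, args, output)
            # A non-None return value replaces the output.
            if result is not None:
                output = result
        return output


def show(module, inputs, output):
    print(f'{type(module).__name__}: inputs={inputs}, output={output}')


m = MiniModule()
m.register_forward_hook(show)   # pass the function, do not call it
y = m(3)                        # prints "MiniModule: inputs=(3,), output=6"
```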
    

  • Using hooks in PyTorch

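    PyTorch modules have two primary kinds of hooks: forward and backward. Besides the regular hook, which runs after the corresponding computation, there is also a pre-hook variant that runs before it (for example register_forward_pre_hook). There are also hooks on other actions, such as load_state_dict. The dispatch works like the toy class above: calling module(x) goes through nn.Module.__call__, which runs forward and then calls every registered forward hook as hook(module, input, output), where input is the tuple of positional arguments passed to the module. That is why you only pass the function object when registering it. As for the bug you mention: Module.register_backward_hook is deprecated because it attaches to the last autograd operation inside the module, so for modules made of several operations grad_input and grad_output may contain only a subset of the gradients. Use register_full_backward_hook instead; its hooks are called as hook(module, grad_input, grad_output).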
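    To see concretely which arguments PyTorch supplies, here is a small sketch on a standalone nn.Linear. The layer sizes, the hook function names and the events list are illustrative choices, not part of your model. It registers a forward pre-hook, a forward hook and a full backward hook, records what each one receives, and finally removes the hooks through the handles returned by the register_* calls:

```python
import torch
import torch.nn as nn

layer = nn.Linear(4, 3)
events = []  # (hook kind, argument shapes) in the order the hooks fire


def forward_pre_hook(module, args):
    # args: tuple of positional inputs passed to layer(...)
    events.append(('forward_pre', tuple(a.shape for a in args)))


def forward_hook(module, args, output):
    events.append(('forward', tuple(a.shape for a in args), output.shape))


def backward_hook(module, grad_input, grad_output):
    # grad_input: gradients w.r.t. the module inputs (None if an input does not require grad)
    # grad_output: gradients w.r.t. the module outputs
    events.append(('backward',
                   tuple(None if g is None else g.shape for g in grad_input),
                   tuple(g.shape for g in grad_output)))


handles = [
    layer.register_forward_pre_hook(forward_pre_hook),
    layer.register_forward_hook(forward_hook),
    layer.register_full_backward_hook(backward_hook),
]

x = torch.rand(2, 4, requires_grad=True)
layer(x).mean().backward()  # forward hooks fire here, then the backward hook

for kind, *shapes in events:
    print(kind, shapes)

for h in handles:
    h.remove()  # the hooks no longer fire after this
```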

    Note that you are registering the hooks with self.get_attention and self.get_attention_gradient, which are bound methods: accessing a method through an instance returns an object that already stores that instance (as its __self__ attribute). When PyTorch later invokes the hook, it passes only the hook arguments, and Python prepends self automatically. In other words, on execution the calls are equivalent to (Routine being the name of your class):

    Routine.get_attention(self, module, input, output)
    Routine.get_attention_gradient(self, module, grad_input, grad_output)
    

    So registering the bound methods directly works as intended. Wrapping the hook in a lambda when you register it is equivalent, but not necessary:

    module.register_forward_hook(
        lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
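    The bound-method behaviour can be checked in plain Python, without PyTorch. A bound method such as self.get_attention already carries its instance, so the caller never needs to pass self. In this sketch Recorder and record are hypothetical names; record has the same (module, input, output) shape as a forward hook, and the caller passes only those three arguments:

```python
class Recorder:
    def __init__(self):
        self.seen = []

    def record(self, module, inputs, output):
        # 'self' is not supplied by the caller: it is stored on the bound method.
        self.seen.append(output)


rec = Recorder()
callback = rec.record              # bound method: function + instance
print(callback.__self__ is rec)    # True

# A hook dispatcher passes only the three hook arguments...
callback('module', ('input',), 'output')
# ...and Python prepends rec as self.
print(rec.seen)                    # ['output']
```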
    

    All in all, your class could look like this:

    import torch
    import torch.nn as nn

    class Routine:
        def __init__(self, model, attention_layer_name):
            self.model = model
    
            for name, module in self.model.named_modules():
                if attention_layer_name in name:
                    module.register_forward_hook(
                        lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
                    # register_full_backward_hook supersedes the deprecated register_backward_hook
                    module.register_full_backward_hook(
                        lambda *args, **kwargs: Routine.get_attention_gradient(self, *args, **kwargs))
    
            # caches filled by the hooks on every forward/backward pass
            self.attentions = []
            self.attention_gradients = []
    
        def get_attention(self, module, input, output):
            # forward hook: called after module.forward with its output
            self.attentions.append(output.cpu())
    
        def get_attention_gradient(self, module, grad_input, grad_output):
            # backward hook: grad_input[0] is the gradient w.r.t. the module's first input
            self.attention_gradients.append(grad_input[0].cpu())
    
        def __call__(self, input_tensor):
            self.model.zero_grad()
            output = self.model(input_tensor)   # triggers the forward hook
            loss = output.mean()
            loss.backward()                     # triggers the backward hook
    

    When initialized with a single linear layer model:

    routine = Routine(nn.Sequential(nn.Linear(10,10)), attention_layer_name='0')
    

    You can call the instance; this first triggers the forward hook (because of self.model(input_tensor)) and then the backward hook (because of loss.backward()).

    >>> routine(torch.rand(1,10, requires_grad=True))
    

    Following your implementation, your forward hook is caching the output of the "attention_layer_name" layer in self.attentions.

    >>> routine.attentions
    [tensor([[-0.3137, -0.2265, -0.2197,  0.2211, -0.6700, 
              -0.5034, -0.1878, -1.1334,  0.2025,  0.8679]], grad_fn=<...>)]
    

    Similarly, the backward hook caches the gradient in self.attention_gradients:

    >>> routine.attention_gradients
    [tensor([[-0.0501,  0.0393,  0.0353, -0.0257,  0.0083,  
               0.0426, -0.0004, -0.0095, -0.0759, -0.0213]])] 
    

    Note that the cached outputs and gradients remain in self.attentions and self.attention_gradients, and new entries are appended on every execution of Routine.__call__, so the lists keep growing unless you clear them.
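    One way to avoid this accumulation is to reset the caches at the start of every call. The sketch below is a hypothetical variant (RoutineWithCleanup and remove_hooks are illustrative names, not PyTorch API). It registers the bound methods directly, since they already carry self. It stores detached copies, because output.cpu() returns the very same tensor when it already lives on the CPU and would therefore keep its autograd graph alive. It also keeps the handles returned by the register_* calls, so the hooks can be detached with handle.remove():

```python
import torch
import torch.nn as nn


class RoutineWithCleanup:
    """Hypothetical variant of Routine: fresh caches per call, removable hooks."""

    def __init__(self, model, attention_layer_name):
        self.model = model
        self.attentions = []
        self.attention_gradients = []
        self.handles = []  # handles returned by the register_* calls
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                # bound methods already carry self, no lambda needed
                self.handles.append(module.register_forward_hook(self.get_attention))
                self.handles.append(
                    module.register_full_backward_hook(self.get_attention_gradient))

    def get_attention(self, module, input, output):
        # detach so the cached tensor does not keep the autograd graph alive
        self.attentions.append(output.detach().cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        grad = grad_input[0]
        # grad_input[0] is None when the layer's input does not require grad
        self.attention_gradients.append(None if grad is None else grad.detach().cpu())

    def remove_hooks(self):
        for handle in self.handles:
            handle.remove()
        self.handles.clear()

    def __call__(self, input_tensor):
        # start each call with empty caches instead of appending forever
        self.attentions.clear()
        self.attention_gradients.clear()
        self.model.zero_grad()
        output = self.model(input_tensor)
        loss = output.mean()
        loss.backward()
```

    Usage mirrors the original: after routine = RoutineWithCleanup(nn.Sequential(nn.Linear(10, 10)), attention_layer_name='0') and two calls to routine(x), each list holds a single entry from the latest call.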