I am trying to understand how exactly, code-wise, hooks operate in PyTorch. I have a model and I would like to register a forward and a backward hook after a specific layer; I guess the easiest way is to register the hooks on that specific module. This introductory video warns that the backward hook contains a bug, but I am not sure whether that is still the case.
My code looks as follows:
def __init__(self, model, attention_layer_name='desired_name_module', discard_ratio=0.9):
    self.model = model
    self.discard_ratio = discard_ratio
    for name, module in self.model.named_modules():
        if attention_layer_name in name:
            module.register_forward_hook(self.get_attention)
            module.register_backward_hook(self.get_attention_gradient)

    self.attentions = []
    self.attention_gradients = []

def get_attention(self, module, input, output):
    self.attentions.append(output.cpu())

def get_attention_gradient(self, module, grad_input, grad_output):
    self.attention_gradients.append(grad_input[0].cpu())

def __call__(self, input_tensor, category_index):
    self.model.zero_grad()
    output = self.model(input_tensor)
    loss = ...
    loss.backward()
I am puzzled about how, code-wise, the following lines work:

module.register_forward_hook(self.get_attention)
module.register_backward_hook(self.get_attention_gradient)

I am registering a hook on my desired module, but then I am passing a function in each case without any arguments. My question is, Python-wise, how does this call work exactly? How do the arguments of register_forward_hook and register_backward_hook get passed when the function is eventually called?
A hook allows you to execute a specific function, referred to as a "callback", when a particular action has been performed. In this case, you expect self.get_attention to be called once the forward function of module has been executed. To give a minimal example of what a hook looks like, I define a simple class on which you can register new callbacks through register_hook; when the instance is called (via __call__), all hooks are called with the provided arguments:
class Obj:
    def __init__(self):
        self.hooks = []

    def register_hook(self, hook):
        self.hooks.append(hook)

    def __call__(self, x, y):
        print('instance called')
        for hook in self.hooks:
            hook(x, y)
First, implement two hooks for demonstration purposes:
def foo(x, y):
    print(f'foo called with {x} and {y}')

def bar(x, _):
    print(f'bar called with {x}')
And initialize an instance of Obj:
obj = Obj()
You can register a hook and call the instance:
>>> obj.register_hook(foo)
>>> obj('yes', 'no')
instance called
foo called with yes and no
You can add more hooks on top and call again to compare; this time both hooks are triggered:
>>> obj.register_hook(bar)
>>> obj('yes', 'no')
instance called
foo called with yes and no
bar called with yes
There are two primary hooks in PyTorch: forward and backward. You also have pre- and post-hooks. Additionally, there are hooks on other actions such as load_state_dict...
To attach a hook on the forward pass of an nn.Module, you should use register_forward_hook; its argument is a callback function that expects module, args, and output. This callback will be triggered on every forward execution.
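For instance, here is a minimal sketch of a forward hook; the linear layer, hook name, and shapes are just placeholders for illustration:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def print_shapes(module, args, output):
    # args is the tuple of positional inputs passed to forward, output is its return value
    print(f'{module.__class__.__name__}: in {args[0].shape}, out {output.shape}')

handle = layer.register_forward_hook(print_shapes)
layer(torch.rand(3, 4))   # triggers the hook on this forward pass
handle.remove()           # detach the hook once it is no longer needed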
For backward hooks, you should use register_full_backward_hook; the registered hook expects three arguments: module, grad_input, and grad_output. As of recent PyTorch versions, register_backward_hook has been deprecated and should not be used.
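A corresponding sketch for the backward side, again with a placeholder layer and hook name:

import torch
import torch.nn as nn

layer = nn.Linear(4, 2)

def print_grads(module, grad_input, grad_output):
    # grad_input / grad_output are tuples of gradients w.r.t. the module's inputs / outputs
    print(f'{module.__class__.__name__}: grad_output[0] has shape {grad_output[0].shape}')

handle = layer.register_full_backward_hook(print_grads)
out = layer(torch.rand(3, 4, requires_grad=True))
out.sum().backward()      # the hook fires during this backward pass
handle.remove()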
One pitfall here is how you register the hooks, with self.get_attention and self.get_attention_gradient. The function passed to the register call is treated as a plain callback: if it is not bound to the class instance, it will be invoked on execution without the self argument, i.e. as:

get_attention(module, input, output)
get_attention_gradient(module, grad_input, grad_output)

which will fail. A simple way to make the binding explicit is to wrap the hook in a lambda when you register it:
module.register_forward_hook(
    lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
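If you prefer, functools.partial achieves the same explicit binding as the lambda; a small sketch:

from functools import partial

# Equivalent to the lambda above: bind the instance as the first argument
module.register_forward_hook(partial(Routine.get_attention, self))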
All in all, your class could look like this:
class Routine:
    def __init__(self, model, attention_layer_name):
        self.model = model
        for name, module in self.model.named_modules():
            if attention_layer_name in name:
                module.register_forward_hook(
                    lambda *args, **kwargs: Routine.get_attention(self, *args, **kwargs))
                module.register_full_backward_hook(
                    lambda *args, **kwargs: Routine.get_attention_gradient(self, *args, **kwargs))

        self.attentions = []
        self.attention_gradients = []

    def get_attention(self, module, input, output):
        self.attentions.append(output.cpu())

    def get_attention_gradient(self, module, grad_input, grad_output):
        self.attention_gradients.append(grad_input[0].cpu())

    def __call__(self, input_tensor):
        self.model.zero_grad()
        output = self.model(input_tensor)
        loss = output.mean()
        loss.backward()
When initialized with a single linear layer model:
routine = Routine(nn.Sequential(nn.Linear(10,10)), attention_layer_name='0')
You can call the instance; this will first trigger the forward hook (because of self.model(input_tensor)), and then the backward hook (because of loss.backward()).
>>> routine(torch.rand(1,10, requires_grad=True))
Following your implementation, your forward hook caches the output of the "attention_layer_name" layer in self.attentions.
>>> routine.attentions
[tensor([[-0.3137, -0.2265, -0.2197, 0.2211, -0.6700,
-0.5034, -0.1878, -1.1334, 0.2025, 0.8679]], grad_fn=<...>)]
Similarly for self.attention_gradients:

>>> routine.attention_gradients
[tensor([[-0.0501, 0.0393, 0.0353, -0.0257, 0.0083,
0.0426, -0.0004, -0.0095, -0.0759, -0.0213]])]
It is important to note that the cached outputs and gradients will remain in self.attentions and self.attention_gradients and get appended on every execution of Routine.__call__.
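If you want each call to start from a clean slate, one option (not part of the original code, just a sketch) is to clear both lists before calling the routine again:

routine.attentions.clear()
routine.attention_gradients.clear()
routine(torch.rand(1, 10, requires_grad=True))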