I am working on a machine translation task using the MarianMTModel from the Hugging Face transformers library. Specifically, I want to visualize the cross-attention matrices during the model's translation process. However, I encountered some difficulties in achieving this.
What I’ve Tried:
Initial Attempt: I noticed that the cross-attention matrices are not returned by default when the model generates a translation. The only example I found involved feeding both the source text and an existing translation to the model. However, my goal is to access the cross-attention matrices while the model generates the output, not for a translation that I supply myself.
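For reference, the approach I found looks roughly like the sketch below (untested; the German sentence is just a reference translation I would have to supply myself, and tokenizer(text_target=...) assumes a reasonably recent transformers version):
# Sketch of the "feed source + reference translation" approach I found.
# The target sentence is an example I wrote, not model output.
from transformers import MarianMTModel, MarianTokenizer
import torch

model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name).eval()

enc = tokenizer("Please translate this to German.", return_tensors="pt")
labels = tokenizer(text_target="Bitte übersetzen Sie das ins Deutsche.",
                   return_tensors="pt").input_ids

with torch.no_grad():
    out = model(**enc, labels=labels, output_attentions=True)

# out.cross_attentions: one tensor per decoder layer,
# each of shape (batch, num_heads, target_len, source_len)
print(out.cross_attentions[0].shape)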
Using Forward Hooks: To achieve this, I registered forward hooks on the key and query projections of each decoder layer's cross-attention, and disabled key-value caching (use_cache=False) so that the full matrices are available at the last decoding step. Here’s my implementation:
# VISUALIZING CROSS ATTENTION FOR TRANSLATION TASK (NOT WORKING YET)
from transformers import MarianMTModel, MarianTokenizer
import torch
import matplotlib.pyplot as plt
from torch.nn import functional as F

model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
model.eval()

keys = {}
queries = {}

def get_key(layer):
    def hook(module, input, output):
        key, = input
        keys[layer] = key
    return hook

def get_query(layer):
    def hook(module, input, output):
        query, = input
        queries[layer] = query
    return hook

def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

hooks = []
for i, layer in enumerate(model.model.decoder.layers):
    hooks.append(layer.encoder_attn.k_proj.register_forward_hook(get_key(i)))
    hooks.append(layer.encoder_attn.q_proj.register_forward_hook(get_query(i)))

input_text = "Please translate this to German."
inputs = tokenizer(input_text, return_tensors="pt")
translated_tokens = model.generate(**inputs, use_cache=False)
translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
output_tokens = tokenizer.convert_ids_to_tokens(translated_tokens[0])

attentions = []
for layer in range(len(keys)):
    K, Q = keys[layer], queries[layer]
    M = Q @ K.transpose(-2, -1)
    attentions.append(F.softmax(M, dim=-1))
attentions = torch.stack(attentions, dim=0)

print("layers, heads, output tokens, input tokens")
print(attentions.shape)

plt.figure(figsize=(10, 8))
plt.imshow(attentions[0, 0], cmap='viridis')
plt.colorbar()
plt.xticks(range(len(input_tokens)), input_tokens, rotation=90)
plt.yticks(range(len(output_tokens)), output_tokens)
plt.xlabel("Input Tokens")
plt.ylabel("Output Tokens")
plt.title("Cross-Attention Matrix")
plt.show()
This approach seemed to capture the cross-attention matrices. However, the resulting tensors only have 4 attention heads instead of the expected 8, which makes me question the correctness of my implementation.
My Question
Given the issues I’ve encountered, is there a more reliable method to extract and visualize the cross-attention matrices during the translation process? Additionally, if my current approach is fundamentally okay, how can I resolve the issue of capturing only 4 attention heads instead of 8?
I suspect the issue might be that I am not reshaping the key (K) and query (Q) tensors into per-head form before the multiplication, but I wanted to ask for advice in case there is an easier or more effective way to do this. A rough sketch of the reshape I have in mind follows below.
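This sketch is untested; it reuses the _shape helper defined above and assumes encoder_attn exposes num_heads and head_dim, which that helper already relies on:
# Rough sketch: split the captured tensors into heads before the matmul,
# then softmax over the source dimension.
attn_module = model.model.decoder.layers[0].encoder_attn
Q, K = queries[0], keys[0]                    # (batch, seq_len, embed_dim)
bsz = Q.shape[0]
q = _shape(attn_module, Q, Q.shape[1], bsz)   # (batch, heads, tgt_len, head_dim)
k = _shape(attn_module, K, K.shape[1], bsz)   # (batch, heads, src_len, head_dim)
scores = q @ k.transpose(-2, -1) / attn_module.head_dim ** 0.5
weights = F.softmax(scores, dim=-1)           # (batch, heads, tgt_len, src_len)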
Hugging Face has built-in support for returning the attention weights:
translated_tokens = model.generate(
    **inputs,
    output_attentions=True,
    return_dict_in_generate=True,
)
print(translated_tokens.keys())
> odict_keys(['sequences', 'encoder_attentions', 'decoder_attentions', 'cross_attentions', 'past_key_values'])
With return_dict_in_generate=True, model.generate returns a dict-like object. With output_attentions=True, that output will contain all attention weights. For this model, it includes encoder attentions, decoder attentions and cross attentions.
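As a rough sketch (based on the documented shapes of the generate output, not tested against a specific transformers version, and reusing model, inputs, torch and matplotlib from your question), you can stack the per-step cross attentions into a full target × source map for one layer and head; forcing num_beams=1 keeps the batch/beam dimension at 1:
out = model.generate(
    **inputs,
    num_beams=1,                 # greedy decoding, so the batch dimension stays 1
    output_attentions=True,
    return_dict_in_generate=True,
)

layer, head = 0, 0
# out.cross_attentions: one entry per generated token; each entry is a tuple
# with one tensor per decoder layer, shaped (batch, heads, query_len, source_len).
rows = [step[layer][0, head, -1, :] for step in out.cross_attentions]
attn_map = torch.stack(rows)     # (target_len, source_len)

plt.imshow(attn_map.numpy(), cmap="viridis")
plt.xlabel("Input Tokens")
plt.ylabel("Output Tokens")
plt.show()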