python, pytorch, nlp, huggingface-transformers

How to Visualize Cross-Attention Matrices in MarianMTModel During Output Generation


I am working on a machine translation task using the MarianMTModel from the Hugging Face transformers library. Specifically, I want to visualize the cross-attention matrices during the model's translation process. However, I encountered some difficulties in achieving this.

What I’ve Tried:

# VISUALIZING CROSS ATTENTION FOR TRANSLATION TASK (NOT WORKING YET)
from transformers import MarianMTModel, MarianTokenizer
import torch
import matplotlib.pyplot as plt
from torch.nn import functional as F

model_name = "Helsinki-NLP/opus-mt-en-de"
tokenizer = MarianTokenizer.from_pretrained(model_name)
model = MarianMTModel.from_pretrained(model_name)
model.eval()

keys = {}
queries = {}

def get_key(layer):
    def hook(module, input, output):
        key, = input
        keys[layer] = key
    return hook

def get_query(layer):
    def hook(module, input, output):
        query, = input
        queries[layer] = query
    return hook

# Helper copied from the model's attention implementation (not actually used below)
def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
    return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

# Register hooks on the cross-attention K/Q projections of every decoder layer
hooks = []
for i, layer in enumerate(model.model.decoder.layers):
    hooks.append(layer.encoder_attn.k_proj.register_forward_hook(get_key(i)))
    hooks.append(layer.encoder_attn.q_proj.register_forward_hook(get_query(i)))

input_text = "Please translate this to German."
inputs = tokenizer(input_text, return_tensors="pt")

# use_cache=False so every generation step re-runs the decoder over the full prefix
translated_tokens = model.generate(**inputs, use_cache=False)

translated_text = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)

input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
output_tokens = tokenizer.convert_ids_to_tokens(translated_tokens[0])

attentions = []
for layer in range(len(keys)):
    K, Q = keys[layer], queries[layer]
    M = Q @ K.transpose(-2, -1)
    attentions.append(F.softmax(M, dim=-1))

attentions = torch.stack(attentions, dim=0)

print("layers, heads, output tokens, input tokens")
print(attentions.shape)
plt.figure(figsize=(10, 8))
plt.imshow(attentions[0, 0], cmap='viridis')
plt.colorbar()

plt.xticks(range(len(input_tokens)), input_tokens, rotation=90)
plt.yticks(range(len(output_tokens)), output_tokens)

plt.xlabel("Input Tokens")
plt.ylabel("Output Tokens")
plt.title("Cross-Attention Matrix")
plt.show()

This approach seemed to capture the cross-attention matrices. However, the resulting matrices have only 4 attention heads instead of the expected 8, which makes me question the correctness of my implementation.

My Question

Given the issues I’ve encountered, is there a more reliable method to extract and visualize the cross-attention matrices during the translation process? Additionally, if my current approach is fundamentally sound, how can I resolve the issue of capturing only 4 attention heads instead of 8?

I suspect the issue is related to the fact that I'm not reshaping the key (K) and query (Q) tensors into per-head form before the multiplication (a rough sketch of what I mean is below), but I wanted to ask for advice in case there is an easier or more effective way to do this.
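
For reference, the kind of per-head reshaping I have in mind is sketched below. It is untested; it assumes the hooks would store the outputs of k_proj / q_proj (the projected keys and queries) rather than their inputs as in the code above, and it reads the head count and head size from the model config.

# Sketch only (untested): split the projected K/Q into heads before the dot product
# and apply the usual 1/sqrt(head_dim) scaling of scaled dot-product attention.
# Assumes keys[layer] / queries[layer] hold the outputs of k_proj / q_proj.
import torch
from torch.nn import functional as F

num_heads = model.config.decoder_attention_heads
head_dim = model.config.d_model // num_heads

def split_heads(x):
    # (batch, seq_len, d_model) -> (batch, num_heads, seq_len, head_dim)
    bsz, seq_len, _ = x.shape
    return x.view(bsz, seq_len, num_heads, head_dim).transpose(1, 2)

attentions = []
for layer in range(len(keys)):
    Q = split_heads(queries[layer])
    K = split_heads(keys[layer])
    scores = (Q @ K.transpose(-2, -1)) / head_dim ** 0.5
    attentions.append(F.softmax(scores, dim=-1))

# shape: (layers, batch, heads, output tokens, input tokens)
attentions = torch.stack(attentions, dim=0)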


Solution

  • Hugging Face has built-in methods to return attention weights:

    translated_tokens = model.generate(**inputs, 
                                       output_attentions=True,
                                       return_dict_in_generate=True
                                      )
    
    print(translated_tokens.keys())
    > odict_keys(['sequences', 'encoder_attentions', 'decoder_attentions', 'cross_attentions', 'past_key_values'])
    

    With return_dict_in_generate=True, model.generate returns a dict-like object. With output_attentions=True, the output dict will contain all attention weights.

    For this model, that means the encoder self-attentions, the decoder self-attentions, and the cross-attentions. A sketch of how to assemble and plot the cross-attentions from this output is shown below.
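
    As a sketch of how to use that output for the visualization: the snippet below forces greedy decoding (num_beams=1) so the batch/beam dimension stays at 1, and it relies on cross_attentions being a tuple with one entry per generated token, where each entry holds one tensor per decoder layer of shape (batch, num_heads, generated_length, source_length).

    import torch
    import matplotlib.pyplot as plt

    out = model.generate(**inputs,
                         num_beams=1,               # greedy decoding keeps the batch dimension at 1
                         output_attentions=True,
                         return_dict_in_generate=True)

    # out.cross_attentions: one entry per generation step; each entry is a tuple with one
    # tensor per decoder layer of shape (batch, num_heads, generated_length, src_len).
    # With caching (the default), generated_length is 1, i.e. one query row per step.
    num_layers = len(out.cross_attentions[0])
    per_layer = []
    for layer in range(num_layers):
        rows = [step[layer][0, :, -1, :] for step in out.cross_attentions]  # (num_heads, src_len) each
        per_layer.append(torch.stack(rows, dim=1))                          # (num_heads, tgt_len, src_len)
    attn = torch.stack(per_layer, dim=0)  # (layers, heads, tgt_len, src_len)

    input_tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0])
    output_tokens = tokenizer.convert_ids_to_tokens(out.sequences[0])

    plt.figure(figsize=(10, 8))
    plt.imshow(attn[0, 0].detach(), cmap="viridis")  # layer 0, head 0
    plt.colorbar()
    plt.xticks(range(len(input_tokens)), input_tokens, rotation=90)
    plt.yticks(range(attn.shape[2]), output_tokens[1:1 + attn.shape[2]])  # token produced at each step
    plt.xlabel("Input Tokens")
    plt.ylabel("Output Tokens")
    plt.title("Cross-Attention Matrix (layer 0, head 0)")
    plt.show()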