pytorch · huggingface-transformers · transformer-model · gpt-2

Why doesn't permuting positional encodings in GPT-2 affect the output as expected?


I'm trying to understand the role of positional encoding in the GPT-2 Transformer model. From what I understand, positional encodings are crucial because they give the model a sense of the order of tokens.

However, I'm confused about the behavior I'm observing:

  1. Permuting Positional Encodings: When I permute the positional encodings while keeping the input tokens the same, the generated output barely changes. I expected significant changes since the positional information should alter the model’s understanding of token order.

  2. Permuting Input Tokens: When I permute the input tokens (while permuting positional encodings in the same manner), the output changes significantly, but it doesn't revert to what it was with the original order.

This behavior is confusing because I expected the output to revert when both the positional encodings and tokens are permuted in the same way.

Could someone help clarify why this is happening? Is there something about how GPT-2 handles positional encoding that I'm missing? How can I modify my code to get the behavior I expect, where permuting both the positional encoding and input tokens in the same way results in the original output?

Thanks in advance!

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

def permute_columns(matrix, permutation):
    # Permute the first len(permutation) columns of `matrix` according to
    # `permutation`, leaving the remaining columns untouched.
    n = len(permutation)
    first_n_columns = matrix[:, :n]
    permuted_columns = first_n_columns[:, permutation]
    remaining_columns = matrix[:, n:]
    new_matrix = torch.hstack((permuted_columns, remaining_columns))
    return new_matrix

model_name = "gpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

permutation = [0, 4, 2, 3, 1]
# permute the positional embedding table: rows of wpe correspond to positions,
# so permuting columns of its transpose permutes the position vectors
model.transformer.wpe.weight.data = permute_columns(model.transformer.wpe.weight.data.T, permutation).T

input_text = "The man ate the cow"
input_ids = tokenizer(input_text, return_tensors="pt").input_ids

# permute the input tokens with the same permutation
input_ids = permute_columns(input_ids, permutation)

outputs = model.generate(input_ids, max_length=50)
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(generated_text)

I attempted to permute the positional encodings in a GPT-2 model, expecting this to change the generated output. Additionally, I tried permuting the input tokens along with the positional encodings in the same manner, anticipating that the output would revert to the original.

What I Expected:

Permuting the positional encodings alone should noticeably change the generated output, and permuting the input tokens together with the positional encodings in the same way should reproduce the original output.

What Actually Happened:

Permuting the positional encodings alone barely changed the output, and permuting both the tokens and the positional encodings changed the output significantly without reverting to the original.


Solution

  • You have to consider the effect of the causal attention mask.

    Causal language models (CLMs) like GPT-2 use an attention mask to prevent tokens from attending to tokens that come later in the sequence. This matters because otherwise the model could "look ahead" and cheat at the next-token prediction task. The attention mask restricts the model so that token i can only attend to tokens j <= i.

    Say we have tokens [0, 1, 2, 3, 4]. Token 0 attends only to itself, token 1 attends to [0, 1], and so on (see the mask sketch at the end of this answer).

    Now consider your permutations.

    If we permute the positional embeddings, we give the model a slightly different signal, but the token order and the attention pattern stay the same: token 0 still attends to token 0, token 1 still attends to [0, 1], and so on. As a result, the output is mostly similar to the base case.

    Now we permute the token order, say to [3, 2, 0, 4, 1]. Token 3, which used to attend to [0, 1, 2, 3], now sits in position 0 and can only attend to itself. Token 2, which used to attend to [0, 1, 2], now sits in position 1 and can only attend to [3, 2]. Permuting the tokens therefore substantially changes what information is routed to which tokens, and the model output changes substantially as a result.

    If you want to study the effect of token order and positional embeddings in isolation, use a BERT-style masked language model, which does not apply a causal attention mask (a sketch of that experiment follows below).
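
    Here is a minimal, standalone sketch of the causal mask (built with torch.tril for illustration; the real mask lives inside GPT-2's attention layers). It shows that the mask is defined over sequence positions only, which is why permuting wpe leaves the attention pattern unchanged while permuting the tokens changes what each position can see:

    import torch

    seq_len = 5
    # Lower-triangular causal mask: row i = query position, column j = key position.
    causal_mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    print(causal_mask.long())
    # tensor([[1, 0, 0, 0, 0],
    #         [1, 1, 0, 0, 0],
    #         [1, 1, 1, 0, 0],
    #         [1, 1, 1, 1, 0],
    #         [1, 1, 1, 1, 1]])

    # The mask cares about positions, not about token identities or wpe rows.
    # After reordering the tokens to [3, 2, 0, 4, 1], each position still sees
    # the same positions but different tokens:
    tokens = torch.tensor([3, 2, 0, 4, 1])
    for i in range(seq_len):
        visible = tokens[causal_mask[i]].tolist()
        print(f"position {i} holds token {tokens[i].item()} and attends to tokens {visible}")
    # position 0 holds token 3 and attends to tokens [3]
    # position 1 holds token 2 and attends to tokens [3, 2]
    # ...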
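
    To actually see the behaviour you expected (permuting tokens and positions together reproduces the original result), a bidirectional model works because its self-attention has no causal mask and is permutation-equivariant once the position ids are permuted along with the tokens. A rough sketch of that experiment, assuming bert-base-uncased and comparing hidden states rather than generated text:

    import torch
    from transformers import AutoModel, AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = AutoModel.from_pretrained("bert-base-uncased").eval()

    input_ids = tokenizer("The man ate the cow", return_tensors="pt").input_ids
    seq_len = input_ids.shape[1]
    perm = torch.randperm(seq_len)

    with torch.no_grad():
        base = model(input_ids).last_hidden_state
        permuted = model(input_ids[:, perm],
                         position_ids=perm.unsqueeze(0)).last_hidden_state

    # Without a causal mask, permuting tokens and position ids together should
    # reproduce the original hidden states, just reordered (up to float noise).
    print(torch.allclose(base[:, perm], permuted, atol=1e-4))  # expected: True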