python, pytorch, nlp, huggingface-transformers

Why doesn't permuting positional encodings in BERT affect the output as expected?


I am working on a Jupyter notebook about Transformers. In the section on positional encodings, I want to demonstrate that the Transformer relies entirely on the positional encoding to understand the order of the sequence. I previously learned from another question I posted that this demonstration only works cleanly for models that don't use masked (causal) attention, unlike GPT-2. So I attempted the same approach with a BERT model (which uses unmasked, bidirectional self-attention) to predict a [MASK] token, but I encountered unexpected results.

What I expected to happen: Permuting only the input IDs and permuting only the positional embeddings should lead to the same predictions as each other, and permuting both together should reproduce the baseline result.

What actually happens: Sometimes the results align with my expectations, but other times permuting only one of the two (either the input IDs or the positional embeddings) gives a different outcome from permuting only the other, even though they occasionally agree.

My question is: Is there something else in Hugging Face's BERT model that might be influenced by position, beyond just the positional encoding?

For completeness, I have included the full code from this part of the notebook below so it can be tried out directly. The important part happens in masked_prediction.

import logging
import warnings

import torch
import ipywidgets as widgets
from IPython.display import display
from transformers import BertForMaskedLM, AutoTokenizer
import matplotlib.pyplot as plt
import torch.nn.functional as F

# suppress renaming warnings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
warnings.simplefilter("ignore", FutureWarning)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

input_ids = torch.Tensor([[]])
tokens = []
permutation = []

output = widgets.Output()

def permute_columns(matrix, permutation=None):
    # Permute the first len(permutation) columns of `matrix`, leaving the
    # remaining columns in place.
    n = len(permutation)
    first_n_columns = matrix[:, :n]
    permuted_columns = first_n_columns[:, permutation]
    remaining_columns = matrix[:, n:]
    new_matrix = torch.hstack((permuted_columns, remaining_columns))
    return new_matrix

def update_permutation(ordered_tags):
    # Rebuild the permutation, keeping [CLS] and [SEP] fixed and reordering
    # only the tokens in between.
    global permutation
    fixed_tokens = [tokens[0]] + ordered_tags + [tokens[-1]]
    permutation = [tokens.index(tag) for tag in fixed_tokens]

def tokenize(text):
    global input_ids, tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    tokens = [tokenizer.decode([token_id]).strip() for token_id in input_ids[0]]
    
    if len(tokens) > 2:
        reorderable_tokens = tokens[1:-1]
    else:
        reorderable_tokens = []
    
    with output:
        output.clear_output(wait=True)
        tags_input.allowed_tags = reorderable_tokens
        tags_input.value = reorderable_tokens
        update_permutation(tags_input.value)

def on_tags_change(change):
    if len(change['new']) != len(tags_input.allowed_tags):
        tags_input.value = tags_input.allowed_tags  # Restore original value


def masked_prediction(input_ids, permutation, permute_input, permute_encoding):
    
    with output:
        output.clear_output(wait=True)  # Clear previous outputs
        
        if input_ids.numel() == 0:
            print("You can't use an empty sequence for prediction")
            return
        
        # Load a fresh model each call so weight permutations don't accumulate
        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        
        if permute_encoding:
            # Permute the first len(permutation) rows of the position-embedding table
            # (transposed so permute_columns, which acts on columns, permutes rows)
            model.bert.embeddings.position_embeddings.weight.data = permute_columns(
                model.bert.embeddings.position_embeddings.weight.T, permutation
            ).T
        if permute_input:
            input_ids = permute_columns(input_ids, permutation)
            
        decoded_text = tokenizer.decode(input_ids[0], skip_special_tokens=False)
            
        with torch.no_grad():
            outputs = model(input_ids)
            
        logits = outputs.logits

        top_k = 5

        mask_token_indices = torch.where(input_ids == tokenizer.mask_token_id)[1]
        print(decoded_text, mask_token_indices, permutation)
        num_masks = len(mask_token_indices)
        if num_masks == 0:
            print("You need to include a [MASK] token for prediction")
            return

        fig, axs = plt.subplots(1, num_masks, figsize=(15, 6))
        
        if num_masks == 1:
            axs = [axs]

        for i, idx in enumerate(mask_token_indices):
            mask_token_logits = logits[0, idx, :]

            softmax_probs = F.softmax(mask_token_logits, dim=0)

            top_token_probs, top_token_ids = torch.topk(softmax_probs, top_k, dim=0)

            predicted_tokens = [tokenizer.decode([token_id]).strip() for token_id in top_token_ids]
            predicted_confidences = top_token_probs.tolist()

            axs[i].bar(predicted_tokens, predicted_confidences, color='blue')
            axs[i].set_xlabel('Predicted Tokens')
            axs[i].set_ylabel('Confidence')
            axs[i].set_title(f'Masked Token at Position {idx.item()}')
            axs[i].set_ylim(0, 1)

        plt.show()

def on_predict_button_click(b):
    masked_prediction(input_ids, permutation, permute_input_checkbox.value, permute_encoding_checkbox.value)

text_input = widgets.Text(placeholder='Write text here to encode.', description='Input:')
text_input.observe(lambda change: tokenize(change['new']), names='value')
tags_input = widgets.TagsInput(value=[], allowed_tags=[], allow_duplicates=False)

# Observe changes in tags order to update the permutation and prevent deletion
tags_input.observe(on_tags_change, names='value')
tags_input.observe(lambda change: update_permutation(change['new']), names='value')

# Create checkboxes for permute_input and permute_encoding
permute_input_checkbox = widgets.Checkbox(value=False, description='Permute Inputs')
permute_encoding_checkbox = widgets.Checkbox(value=False, description='Permute Encodings')

# Create a button to trigger the prediction
predict_button = widgets.Button(description="Run Prediction")
predict_button.on_click(on_predict_button_click)

# Display the widgets
display(text_input)
display(tags_input)
display(permute_input_checkbox)
display(permute_encoding_checkbox)
display(predict_button)
display(output)

Solution

  • The model inputs have token ids and position ids. There are four scenarios to consider:

    1. Baseline. Correct order for tokens and positions
    2. Permute position ids only
    3. Permute token ids only
    4. Permute position ids and token ids

    You are correct that scenarios 1 and 4 should produce the same results. However, you are incorrect in assuming that permuting only the tokens and permuting only the positions should give the same result as each other. Consider:

    # Given:
    tokens = [0, 1, 2]
    positions = [0, 1, 2]
    permutation = [2, 0, 1]
    
    # Ex1: Permute tokens but not positions
    [2, 0, 1] # permuted tokens
    [0, 1, 2] # standard positions
    
    # Ex2: Permute positions but not tokens
    [0, 1, 2] # standard tokens
    [2, 0, 1] # permuted positions
    

    In Ex1, the model is told that token 2 occurs at position 0. In Ex2, the model is told that token 2 occurs at position 1. Even though we used the same permutation, the mapping of tokens to positions is different. This results in different model outputs.
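
    To make the token-to-position pairing explicit, here is a minimal standalone sketch (plain Python, no model involved; the toy token and position values are the same as above) that prints the pairing each case produces:

    tokens = [0, 1, 2]
    positions = [0, 1, 2]
    permutation = [2, 0, 1]
    
    def pairs(toks, poss):
        # pair the token and the position presented at each sequence index
        return list(zip(toks, poss))
    
    print(pairs(tokens, positions))                            # baseline:          [(0, 0), (1, 1), (2, 2)]
    print(pairs([tokens[i] for i in permutation], positions))  # permute tokens:    [(2, 0), (0, 1), (1, 2)]
    print(pairs(tokens, [positions[i] for i in permutation]))  # permute positions: [(0, 2), (1, 0), (2, 1)]

    Token 2 is paired with position 0 in the first permuted case but with position 1 in the second, which is exactly the difference described above.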

    The reason you sometimes see these results line up is that you can, through random chance, sample a permutation for which the token/position pairings happen to line up the same way (or mostly the same way) even when only one of them is permuted. This is luck; the average case produces different results.

    It is simple to test this. Hugging Face models accept a position_ids input parameter, so we can test permutations of the positions (and of the input ids) without touching the weight matrices.

    To test this, we'll create input data, permute it as needed for each scenario, compute the logits, and compare them.

    When comparing logits, we permute or un-permute as needed so the comparison is done token by token. For example, if the token at index i in scenario 1 ends up at index j in scenario 3, we compare the logits at index i from scenario 1 against the logits at index j from scenario 3.

    import torch
    from transformers import BertForMaskedLM, AutoTokenizer
    
    def get_logits(inputs):
        with torch.no_grad():
            outputs = model(**inputs)  
            logits = outputs.logits
        return logits
    
    def permute_inputs(inputs, permutation, permute_ids=True, permute_positions=True):
        outputs = {}
        for k,v in inputs.items():
            if k=='position_ids' and permute_positions:
                outputs[k] = v[permutation]
            elif k!='position_ids' and permute_ids:
                outputs[k] = v[:,permutation]
            else:
                outputs[k] = v
                
        return outputs
    
    # load tokenizer/model
    tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
    model = BertForMaskedLM.from_pretrained("bert-base-uncased")
    model.eval() # remember to set model to eval
    
    # create input ids and position ids
    inputs = tokenizer('input text test sequence', return_tensors='pt')
    
    inputs['position_ids'] = torch.arange(inputs['input_ids'].shape[1])
    
    # create permutation tensor
    permutation = torch.randperm(inputs['input_ids'].shape[1])
    
    # compute scenario data
    data = {
        's1' : { # scenario 1 - baseline
            'inputs' : inputs,
            'permuted_ids' : False
        },
        's2' : { # scenario 2 - permute positions only
            'inputs' : permute_inputs(inputs, permutation, permute_ids=False, permute_positions=True),
            'permuted_ids' : False
        },
        's3' : { # scenario 3 - permute token ids only
            'inputs' : permute_inputs(inputs, permutation, permute_ids=True, permute_positions=False),
            'permuted_ids' : True
        },
        's4' : { # scenario 4 - permute tokens and positions
            'inputs' : permute_inputs(inputs, permutation),
            'permuted_ids' : True
        }
    }
    
    # compute logits
    for k,v in data.items():
        v['logits'] = get_logits(v['inputs'])
    
    comparisons = [
        ['s1', 's2'],
        ['s1', 's3'],
        ['s1', 's4'],
        ['s2', 's3'],
        ['s2', 's4'],
        ['s3', 's4'],
    ]
    
    # compare scenarios 
    for sa, sb in comparisons:
        data_a = data[sa]
        data_b = data[sb]
        
        logits_a = data_a['logits']
        logits_b = data_b['logits']
        
        if data_a['permuted_ids'] == data_b['permuted_ids']:
            # either both logits are permuted or both logits are unpermuted
            # so we can compare directly
            val = (logits_a - logits_b).abs().mean()
        elif data_a['permuted_ids'] and (not data_b['permuted_ids']):
            # if `a` is permuted but `b` is not, we permute `b` to make tokens line up
            val = (logits_a - logits_b[:,permutation]).abs().mean()
        else:
            # otherwise we permute `b` to make tokens line up
            val = (logits_a[:,permutation] - logits_b).abs().mean()
            
        print(f"Comparison {sa}, {sb}: {val.item():.6f}")
    

    The code should produce an output like:

    Comparison s1, s2: 1.407895
    Comparison s1, s3: 1.583560
    Comparison s1, s4: 0.000003
    Comparison s2, s3: 1.750883
    Comparison s2, s4: 1.407894
    Comparison s3, s4: 1.583560
    

    Run the code a number of times. You will find that the s1, s4 comparison always shows only a tiny deviation. This is because permuting tokens and positions together always produces the same result as the baseline, up to small differences caused by floating-point numerics.
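
    If you want to check this programmatically rather than by eye, a quick sketch (reusing data and permutation from the script above; the 1e-4 tolerance is an assumption that should comfortably cover float32 noise here, not a value taken from the test) is:

    # Same alignment as in the comparison loop: s4's tokens are permuted,
    # so permute s1's logits before comparing element-wise.
    s1_logits = data['s1']['logits']
    s4_logits = data['s4']['logits']
    print(torch.allclose(s1_logits[:, permutation], s4_logits, atol=1e-4))  # expected: True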

    You will find that the s2, s3 comparison generally shows a large deviation, but occasionally a small one. As discussed, this happens when you draw a lucky permutation for which the positions and ids mostly line up.
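
    One rough way to spot such a lucky draw (an illustrative heuristic, not something the test above relies on): count how many indices the sampled permutation leaves unchanged. The more fixed points, the more tokens keep their original token/position pairing in scenarios 2 and 3, and the closer the two runs tend to be:

    # Rough proxy only: attention mixes all positions, so even one moved
    # token perturbs every output a little.
    n = permutation.shape[0]
    fixed_points = (permutation == torch.arange(n)).sum().item()
    print(f"{fixed_points}/{n} indices unchanged by this permutation")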