I am working on a Jupyter notebook about Transformers. In the section on positional encodings, I want to demonstrate that the Transformer relies entirely on positional encoding to understand the order of the sequence. I previously learned from another question I posted that this only applies to models that don't use causal masked attention (unlike GPT-2, which does). However, when I attempted the same approach with a BERT model (which uses bidirectional, unmasked self-attention) to predict a [MASK] token, I encountered unexpected results.
What I expected to happen: permuting only the input IDs and permuting only the positional embeddings (with the same permutation) should produce identical predictions for the [MASK] token, and permuting both together should reproduce the unpermuted result.
What actually happens: sometimes the results align with my expectations, but most of the time permuting only one of the two (either the input IDs or the positional embeddings) leads to a different outcome than permuting the other; only occasionally do they produce the same result.
My question is: Is there something else in Hugging Face's BERT model that might be influenced by position, beyond just the positional encoding?
For completeness, I have included the full code from this part of the notebook below, so it can be tried out directly. The important part happens in `masked_prediction`.
import torch
import logging
import warnings
import ipywidgets as widgets
from IPython.display import display
from transformers import BertForMaskedLM, AutoTokenizer
import matplotlib.pyplot as plt
import torch.nn.functional as F

# suppress renaming warnings
logging.getLogger("transformers.modeling_utils").setLevel(logging.ERROR)
warnings.simplefilter("ignore", FutureWarning)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

input_ids = torch.Tensor([[]])
tokens = []
permutation = []
output = widgets.Output()

def permute_columns(matrix, permutation=None):
    n = len(permutation)
    first_n_columns = matrix[:, :n]
    permuted_columns = first_n_columns[:, permutation]
    remaining_columns = matrix[:, n:]
    new_matrix = torch.hstack((permuted_columns, remaining_columns))
    return new_matrix

def update_permutation(ordered_tags):
    global permutation
    fixed_tokens = [tokens[0]] + ordered_tags + [tokens[-1]]
    permutation = [tokens.index(tag) for tag in fixed_tokens]

def tokenize(text):
    global input_ids, tokens
    input_ids = tokenizer(text, return_tensors="pt").input_ids
    tokens = [tokenizer.decode([token_id]).strip() for token_id in input_ids[0]]
    if len(tokens) > 2:
        reorderable_tokens = tokens[1:-1]
    else:
        reorderable_tokens = []
    with output:
        output.clear_output(wait=True)
    tags_input.allowed_tags = reorderable_tokens
    tags_input.value = reorderable_tokens
    update_permutation(tags_input.value)

def on_tags_change(change):
    if len(change['new']) != len(tags_input.allowed_tags):
        tags_input.value = tags_input.allowed_tags  # Restore original value

def masked_prediction(input_ids, permutation, permute_input, permute_encoding):
    with output:
        output.clear_output(wait=True)  # Clear previous outputs
        if input_ids.numel() == 0:
            print("You can't use an empty sequence for prediction")
            return
        model = BertForMaskedLM.from_pretrained("bert-base-uncased")
        if permute_encoding:
            # permute the first len(permutation) rows of the position embedding matrix
            model.bert.embeddings.position_embeddings.weight.data = permute_columns(
                model.bert.embeddings.position_embeddings.weight.T, permutation).T
        if permute_input:
            input_ids = permute_columns(input_ids, permutation)
        decoded_text = tokenizer.decode(input_ids[0], skip_special_tokens=False)
        with torch.no_grad():
            outputs = model(input_ids)
        logits = outputs.logits
        top_k = 5
        mask_token_indices = torch.where(input_ids == tokenizer.mask_token_id)[1]
        print(decoded_text, mask_token_indices, permutation)
        num_masks = len(mask_token_indices)
        if num_masks == 0:
            print("You need to include a [MASK] token for prediction")
            return
        fig, axs = plt.subplots(1, num_masks, figsize=(15, 6))
        if num_masks == 1:
            axs = [axs]
        for i, idx in enumerate(mask_token_indices):
            mask_token_logits = logits[0, idx, :]
            softmax_probs = F.softmax(mask_token_logits, dim=0)
            top_token_probs, top_token_ids = torch.topk(softmax_probs, top_k, dim=0)
            predicted_tokens = [tokenizer.decode([token_id]).strip() for token_id in top_token_ids]
            predicted_confidences = top_token_probs.tolist()
            axs[i].bar(predicted_tokens, predicted_confidences, color='blue')
            axs[i].set_xlabel('Predicted Tokens')
            axs[i].set_ylabel('Confidence')
            axs[i].set_title(f'Masked Token at Position {idx.item()}')
            axs[i].set_ylim(0, 1)
        plt.show()

def on_predict_button_click(b):
    masked_prediction(input_ids, permutation, permute_input_checkbox.value, permute_encoding_checkbox.value)

text_input = widgets.Text(placeholder='Write text here to encode.', description='Input:')
text_input.observe(lambda change: tokenize(change['new']), names='value')

tags_input = widgets.TagsInput(value=[], allowed_tags=[], allow_duplicates=False)
# Observe changes in tags order to update the permutation and prevent deletion
tags_input.observe(on_tags_change, names='value')
tags_input.observe(lambda change: update_permutation(change['new']), names='value')

# Create checkboxes for permute_input and permute_encoding
permute_input_checkbox = widgets.Checkbox(value=False, description='Permute Inputs')
permute_encoding_checkbox = widgets.Checkbox(value=False, description='Permute Encodings')

# Create a button to trigger the prediction
predict_button = widgets.Button(description="Run Prediction")
predict_button.on_click(on_predict_button_click)

# Display the widgets
display(text_input)
display(tags_input)
display(permute_input_checkbox)
display(permute_encoding_checkbox)
display(predict_button)
display(output)
The model inputs have token ids and position ids. There are four scenarios to consider:

1. Baseline: permute nothing.
2. Permute the position ids only.
3. Permute the token ids only.
4. Permute both the token ids and the position ids with the same permutation.
You are correct that scenarios 1 and 4 should produce the same results. However, you are incorrect in assuming that permuting tokens or positions separately should give the same result. Consider:
# Given:
tokens = [0, 1, 2]
positions = [0, 1, 2]
permutation = [2, 0, 1]
# Ex1: Permute tokens but not positions
[2, 0, 1] # permuted tokens
[0, 1, 2] # standard positions
# Ex2: Permute positions but not tokens
[0, 1, 2] # standard tokens
[2, 0, 1] # permuted positions
In Ex1, the model is told that token `2` occurs at position `0`. In Ex2, the model is told that token `2` occurs at position `1`. Even though we used the same permutation, the mapping of tokens to positions is different. This results in different model outputs.
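To see the mismatch at a glance, here is a small sketch (plain Python, no model needed) that prints the token-to-position pairing in each case:

tokens = [0, 1, 2]
positions = [0, 1, 2]
permutation = [2, 0, 1]

# Ex1: permute tokens, keep positions
ex1 = list(zip([tokens[i] for i in permutation], positions))
# Ex2: keep tokens, permute positions
ex2 = list(zip(tokens, [positions[i] for i in permutation]))

print(ex1)  # [(2, 0), (0, 1), (1, 2)] -> token 2 sits at position 0
print(ex2)  # [(0, 2), (1, 0), (2, 1)] -> token 2 sits at position 1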
The reason you sometimes see these results line up is because you can (through random chance) sample a permutation that results in token/position embeddings lining up the same way (or mostly the same way) when permuting just one of them. This is luck - the average case produces different results.
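If you want to quantify the luck factor, a quick illustrative experiment is to sample random permutations and count how many positions they leave unchanged. A uniformly random permutation leaves only one position fixed on average, so full agreement is rare:

import torch

n, trials = 8, 10_000
fixed = [(torch.randperm(n) == torch.arange(n)).sum().item() for _ in range(trials)]
print(sum(fixed) / trials)                  # ~1.0 fixed point on average
print(sum(f == n for f in fixed) / trials)  # identity permutation: ~1/n! of the time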
It is simple to test this. Hugging Face models take a `position_ids` input parameter. We can use this to test permutations of the input ids without messing with the weight matrices.
To test this, we'll create input data, permute it as needed, compute logits, and compare the logits. When comparing, we permute or depermute as needed so the comparison is on a token-to-token basis. For example, if the token at position `i` in scenario 1 ends up at position `j` in scenario 3, we compare the logits at position `i` from scenario 1 to the logits at position `j` from scenario 3.
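As a toy illustration of this alignment rule, here is a sketch with stand-in logits, constructed so the "permuted scenario" is by definition a row reordering of the baseline:

import torch

perm = torch.randperm(5)
logits_plain = torch.randn(1, 5, 3)  # stand-in for (batch, seq, vocab) logits
logits_perm = logits_plain[:, perm]  # idealized logits from a permuted scenario
# Comparing row-for-row mismatches tokens...
print((logits_plain - logits_perm).abs().mean())           # generally > 0
# ...but indexing the unpermuted side with `perm` pairs tokens correctly
print((logits_plain[:, perm] - logits_perm).abs().mean())  # exactly 0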
import torch
from transformers import BertForMaskedLM, AutoTokenizer

def get_logits(inputs):
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    return logits

def permute_inputs(inputs, permutation, permute_ids=True, permute_positions=True):
    outputs = {}
    for k, v in inputs.items():
        if k == 'position_ids' and permute_positions:
            outputs[k] = v[permutation]
        elif k != 'position_ids' and permute_ids:
            outputs[k] = v[:, permutation]
        else:
            outputs[k] = v
    return outputs

# load tokenizer/model
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = BertForMaskedLM.from_pretrained("bert-base-uncased")
model.eval()  # remember to set the model to eval mode

# create input ids and position ids
inputs = tokenizer('input text test sequence', return_tensors='pt')
inputs['position_ids'] = torch.arange(inputs['input_ids'].shape[1])

# create permutation tensor
permutation = torch.randperm(inputs['input_ids'].shape[1])

# compute scenario data
data = {
    's1': {  # scenario 1 - baseline
        'inputs': inputs,
        'permuted_ids': False
    },
    's2': {  # scenario 2 - permute positions only
        'inputs': permute_inputs(inputs, permutation, permute_ids=False, permute_positions=True),
        'permuted_ids': False
    },
    's3': {  # scenario 3 - permute token ids only
        'inputs': permute_inputs(inputs, permutation, permute_ids=True, permute_positions=False),
        'permuted_ids': True
    },
    's4': {  # scenario 4 - permute tokens and positions
        'inputs': permute_inputs(inputs, permutation),
        'permuted_ids': True
    }
}

# compute logits
for k, v in data.items():
    v['logits'] = get_logits(v['inputs'])

comparisons = [
    ['s1', 's2'],
    ['s1', 's3'],
    ['s1', 's4'],
    ['s2', 's3'],
    ['s2', 's4'],
    ['s3', 's4'],
]

# compare scenarios
for sa, sb in comparisons:
    data_a = data[sa]
    data_b = data[sb]
    logits_a = data_a['logits']
    logits_b = data_b['logits']
    if data_a['permuted_ids'] == data_b['permuted_ids']:
        # either both logits are permuted or both are unpermuted,
        # so we can compare directly
        val = (logits_a - logits_b).abs().mean()
    elif data_a['permuted_ids'] and (not data_b['permuted_ids']):
        # `a` is permuted but `b` is not, so we permute `b` to make tokens line up
        val = (logits_a - logits_b[:, permutation]).abs().mean()
    else:
        # `b` is permuted but `a` is not, so we permute `a` to make tokens line up
        val = (logits_a[:, permutation] - logits_b).abs().mean()
    print(f"Comparison {sa}, {sb}: {val.item():.6f}")
The code should produce an output like:
Comparison s1, s2: 1.407895
Comparison s1, s3: 1.583560
Comparison s1, s4: 0.000003
Comparison s2, s3: 1.750883
Comparison s2, s4: 1.407894
Comparison s3, s4: 1.583560
Run the code a bunch of times. You will find that the `s1, s4` comparison always has a small deviation. This is because permuting tokens and positions together always produces the same result, ignoring tiny deviations caused by floating-point numerics.
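To convince yourself that the `s1, s4` gap is pure floating-point noise, you can check that the aligned logits agree within a small tolerance. This sketch reuses the `data` dict and `permutation` from the script above:

logits_s1 = data['s1']['logits']
logits_s4 = data['s4']['logits']
# s4's tokens are permuted, so index s1 with the permutation before comparing
print(torch.allclose(logits_s1[:, permutation], logits_s4, atol=1e-4))  # True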
You will find the `s2, s3` comparison generally has a large deviation, but sometimes has a small deviation. As discussed, this is due to getting a lucky permutation where positions and ids mostly line up.