I want to fine-tune BERT on a specific dataset. My problem is that I do not want to mask tokens of my training dataset randomly; I have already chosen which tokens I want to mask (for certain reasons).
To do so, I created a dataset that has two columns: text, in which some tokens have been replaced with [MASK] (I am aware that some words may be tokenised into more than one token, and I took care of that), and label, which contains the whole original text.
Now I want to fine-tune a BERT model (say, bert-base-uncased) using Hugging Face's transformers library, but I do not want to use DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.2), where the masking is done randomly and I can only control the probability. What can I do?
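For context, this is a minimal sketch of the standard random-masking setup I am trying to replace; tokenized_ds and the output directory are placeholders, not part of my actual code:

from transformers import (AutoModelForMaskedLM, AutoTokenizer,
                          DataCollatorForLanguageModeling, Trainer, TrainingArguments)

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# the stock collator picks the masked positions at random;
# only the masking probability can be controlled
data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm_probability=0.2)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="mlm-finetune"),
    train_dataset=tokenized_ds,  # placeholder: my tokenized training split
    data_collator=data_collator,
)
trainer.train()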
This is what I did to solve my problem. I created a custom data collator class and changed the masking logic to what I needed (mask one of the numerical spans in each input).
import torch
from itertools import groupby
from operator import itemgetter

from transformers import DataCollatorForLanguageModeling


class CustomDataCollator(DataCollatorForLanguageModeling):
    mlm: bool = True
    return_tensors: str = "pt"

    def __post_init__(self):
        if self.mlm and self.tokenizer.mask_token is None:
            raise ValueError(
                "This tokenizer does not have a mask token which is necessary "
                "for masked language modeling. You should pass `mlm=False` to "
                "train on causal language modeling instead."
            )

    def torch_mask_tokens(self, inputs, special_tokens_mask=None):
        """
        Prepare masked tokens inputs/labels for masked language modeling.
        NOTE: keep `special_tokens_mask` as an argument to avoid an error,
        since the parent class passes it in.
        """
        # labels is a (batch_size, seq_len) tensor holding the original token ids;
        # seq_len includes the two special tokens ([CLS] and [SEP])
        labels = inputs.clone()
        batch_size = inputs.size(0)
        # in each sequence, find the positions of tokens that represent digits
        # (the bert-base-uncased token ids for "0".."9")
        dig_ids = [1014, 1015, 1016, 1017, 1018, 1019, 1020, 1021, 1022, 1023]
        dig_idx = torch.zeros_like(labels)
        for dig_id in dig_ids:
            dig_idx += (labels == dig_id)
        dig_idx = dig_idx.bool()
        # in each sequence, find the spans of consecutive Trues using `find_spans`
        spans = []
        for i in range(batch_size):
            spans.append(find_spans(dig_idx[i].tolist()))
        masked_indices = torch.zeros_like(labels)
        # `spans` is a list of lists of (start index, length) tuples;
        # in each child list, choose one tuple at random and mask that span
        for i in range(batch_size):
            if len(spans[i]) > 0:
                idx = torch.randint(0, len(spans[i]), (1,))
                start, length = spans[i][idx[0]]
                masked_indices[i, start:start + length] = 1
            else:
                print("No digit found in the sequence!")
        masked_indices = masked_indices.bool()
        # we only compute loss on masked tokens
        labels[~masked_indices] = -100
        # replace the masked positions in the input with the [MASK] token id
        inputs[masked_indices] = self.tokenizer.mask_token_id
        return inputs, labels


def find_spans(lst):
    # return a (start, length) tuple for every run of consecutive truthy values
    spans = []
    for k, g in groupby(enumerate(lst), key=itemgetter(1)):
        if k:
            glist = list(g)
            spans.append((glist[0][0], len(glist)))
    return spans
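As a quick sanity check, find_spans simply turns a boolean list into (start, length) runs: find_spans([False, True, True, False, True]) returns [(1, 2), (4, 1)].

For completeness, a minimal sketch of how the collator can be plugged into a Trainer; tokenized_ds and the output directory are placeholders, not part of the code above:

from transformers import AutoModelForMaskedLM, AutoTokenizer, Trainer, TrainingArguments

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
model = AutoModelForMaskedLM.from_pretrained("bert-base-uncased")

# swap the stock DataCollatorForLanguageModeling for the custom collator
data_collator = CustomDataCollator(tokenizer=tokenizer, mlm=True)

trainer = Trainer(
    model=model,
    args=TrainingArguments(output_dir="custom-mlm"),
    train_dataset=tokenized_ds,  # placeholder: dataset tokenized with the same tokenizer
    data_collator=data_collator,
)
trainer.train()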