I use the following code to count what percentage of a paragraph's characters end up encoded as unknown tokens.
from transformers import BertTokenizer

paragraph_chinese = '...'  # a long paragraph read from a text file

tokenizer_bart = BertTokenizer.from_pretrained("fnlp/bart-base-chinese")
encoded_chinese_bart = tokenizer_bart.encode(paragraph_chinese)
# convert_tokens_to_ids takes a list of tokens and returns a list of IDs
unk_token_id_bart = tokenizer_bart.convert_tokens_to_ids(["[UNK]"])
len_paragraph_chinese = len(paragraph_chinese)
unk_token_cnt_chinese_bart = encoded_chinese_bart.count(unk_token_id_bart[0])
print("BART Unknown Token count in Chinese Paragraph:", unk_token_cnt_chinese_bart,
      "(" + str(unk_token_cnt_chinese_bart * 100 / len_paragraph_chinese) + "%)")
print(type(tokenizer_bart))
which prints:
BART Unknown Token count in Chinese Paragraph: 1 (0.015938795027095953%)
<class 'transformers.models.bert.tokenization_bert.BertTokenizer'>
My question is: I noticed there is one unknown token. How can I find out which word caused it?
P.S. I tried print(encoded_chinese_bart), but it only prints a list of token IDs.
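For example, mapping those IDs back to token strings with convert_ids_to_tokens (continuing from the snippet above) still only shows the [UNK] placeholder, not the original word:

tokens = tokenizer_bart.convert_ids_to_tokens(encoded_chinese_bart)
# prints ['[UNK]'] -- the placeholder carries no trace of the source word
print([t for t in tokens if t == "[UNK]"])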
I'm using transformers 4.28.1.
When you use BertTokenizerFast instead of the "slow" BertTokenizer and call the tokenizer directly (rather than using .encode()), you get a BatchEncoding object with several convenience methods for mapping a token back to its span in the original string.
The following code uses the token_to_chars method:
from transformers import BertTokenizerFast

# just an example
paragraph_chinese = '马云 Kočka 祖籍浙江省嵊县 Kočka 现嵊州市'

tokenizer_bart = BertTokenizerFast.from_pretrained("fnlp/bart-base-chinese")
# calling the tokenizer directly returns a BatchEncoding
encoded_chinese_bart = tokenizer_bart(paragraph_chinese)
unk_token_id_bart = tokenizer_bart.unk_token_id
len_paragraph_chinese = len(paragraph_chinese)
unk_token_cnt_chinese_bart = encoded_chinese_bart.input_ids.count(unk_token_id_bart)
print(f'BART Unknown Token count in Chinese Paragraph: {unk_token_cnt_chinese_bart} ({unk_token_cnt_chinese_bart * 100 / len_paragraph_chinese}%)')

# find the indices of all [UNK] tokens
unk_indices = [i for i, x in enumerate(encoded_chinese_bart.input_ids) if x == unk_token_id_bart]
for unk_i in unk_indices:
    # token_to_chars returns the character span of this token in the original string
    start, stop = encoded_chinese_bart.token_to_chars(unk_i)
    print(f"At {start}:{stop}: {paragraph_chinese[start:stop]}")
Output:
BART Unknown Token count in Chinese Paragraph: 2 (7.407407407407407%)
At 3:8: Kočka
At 17:22: Kočka
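As a side note, the same lookup can be done without token_to_chars: fast tokenizers also accept return_offsets_mapping=True, which adds a character (start, end) pair for every token. A minimal sketch, assuming the same paragraph_chinese and tokenizer_bart as above:

# offset_mapping[i] is the character span of input_ids[i] in the original text
encoding = tokenizer_bart(paragraph_chinese, return_offsets_mapping=True)
for token_id, (start, stop) in zip(encoding.input_ids, encoding.offset_mapping):
    if token_id == tokenizer_bart.unk_token_id:
        print(f"At {start}:{stop}: {paragraph_chinese[start:stop]}")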