bert-language-model, huggingface-transformers

BERT sentence embeddings from transformers


I'm trying to get sentence vectors from hidden states in a BERT model. I'm looking at the huggingface BertModel instructions here, which say:

from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
model = BertModel.from_pretrained("bert-base-multilingual-cased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt') 
output = model(**encoded_input)

So first note, as it is on the website, this does /not/ run. You get:

>>> Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
TypeError: 'BertTokenizer' object is not callable

But it looks like a minor change fixes it, in that you don't call the tokenizer directly, but ask it to encode the input:

encoded_input = tokenizer.encode(text, return_tensors="pt")
output = model(encoded_input)

OK, that aside, the tensor I get has a different shape than I expected:

>>> output[0].shape
torch.Size([1, 11, 768])

This is a lot of layers. Which is the correct layer to use for sentence embeddings? [0]? [-1]? Averaging several? I have the goal of being able to do cosine similarity with these, so I need a proper 1xN vector rather than an NxK tensor.

I see that the popular bert-as-a-service project appears to use [0].

Is this correct? Is there documentation for what each of the layers are?


Solution

  • I don't think there is a single authoritative piece of documentation saying what to use and when. You need to experiment and measure what works best for your task. Recent observations about BERT are nicely summarized in this paper: https://arxiv.org/pdf/2002.12327.pdf.

    I think the rule of thumb is:

    Bert-as-service uses the last layer by default (but it is configurable). Here, it would be [:, -1]. However, it always returns a list of vectors for all input tokens. The vector corresponding to the first special (so-called [CLS]) token is considered to be the sentence embedding. This is where the [0] comes from in the snippet you refer to.
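
    For concreteness, here is a minimal sketch of that rule of thumb, assuming your goal is cosine similarity between sentences: take the last hidden layer the model returns, keep only the vector of the first ([CLS]) token, and compare those vectors. The model name and the cls_embedding helper are illustrative, not something prescribed by the library.

    import torch
    from transformers import BertTokenizer, BertModel

    tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
    model = BertModel.from_pretrained('bert-base-multilingual-cased')

    def cls_embedding(text):
        # encode() returns token ids of shape [1, num_tokens]
        input_ids = tokenizer.encode(text, return_tensors='pt')
        with torch.no_grad():
            output = model(input_ids)
        # output[0] is the last hidden layer, shape [1, num_tokens, 768];
        # [:, 0] keeps only the first ([CLS]) token -> shape [1, 768]
        return output[0][:, 0]

    a = cls_embedding("Replace me by any text you'd like.")
    b = cls_embedding("Substitute any text you want here.")
    print(torch.nn.functional.cosine_similarity(a, b))  # tensor with one similarity value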