machine-learning, pytorch, nlp, huggingface-transformers, language-model

How to get the embedding of any vocabulary token in GPT?


I have a GPT model

from transformers import BioGptForCausalLM

model = BioGptForCausalLM.from_pretrained("microsoft/biogpt").to(device)

When I send my batch to it I can get the logits and the hidden states:

out = model(batch["input_ids"].to(device), output_hidden_states=True, return_dict=True)
print(out.keys())
>>> odict_keys(['logits', 'past_key_values', 'hidden_states'])

The logits have shape of

torch.Size([2, 1024, 42386]) # batch of size 2, sequence length = 1024, vocab size = 42386

Corresponding to (batch, seq_length, vocab_size). If I understand correctly, for each token in the sequence the logits are a vector of size vocab_size that, after a softmax, tells the model which token from the vocabulary to predict. I believe that each of these vocabulary tokens should have an embedding.
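
As a rough sketch of that reading (assuming out is the output from the call above), the softmax turns each logits vector into a probability distribution over the vocabulary:

import torch.nn.functional as F

# Softmax over the vocabulary dimension turns logits into probabilities
probs = F.softmax(out.logits, dim=-1)   # [2, 1024, 42386], each row sums to 1
# Most likely vocabulary id at every position
predicted_ids = probs.argmax(dim=-1)    # [2, 1024]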

From my previous question I found how to get the embeddings of each sequence token (shape [2, 1024, 1024] in my setting). But how can I get the embeddings of each token in the model's vocabulary? I would expect this to have size [2, 1024, 42386, 1024] (BioGPT has a hidden size of 1024).

I'm mainly interested in just a few special tokens (e.g., indices 1, 2, 6, 112 out of the 42386).


Solution

  • If I understand correctly, you want an embedding representing a single token from the vocabulary. There are two answers that I know of, depending on which embedding you want exactly.

    1st solution

    The first layer in the model is a torch.nn.Embedding, which is in effect a lookup table (equivalent to a bias-free linear layer applied to one-hot inputs), so it has a weight parameter of shape [V, D], where V is the vocabulary size (42386 for you) and D is the embedding dimension (1024). You can access the representation of token k with model.biogpt.embed_tokens.weight[k]. This is the 1024-dimensional vector that directly represents the k-th token.
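
    A minimal sketch of this first approach, including the handful of special indices mentioned in the question (the attribute path model.biogpt.embed_tokens follows the Hugging Face BioGPT implementation):

    import torch
    from transformers import BioGptForCausalLM

    model = BioGptForCausalLM.from_pretrained("microsoft/biogpt")

    # The static embedding matrix: one row per vocabulary token
    emb_matrix = model.biogpt.embed_tokens.weight       # shape [42386, 1024]

    # Rows for a few specific vocabulary indices
    token_ids = torch.tensor([1, 2, 6, 112])
    special_embeddings = emb_matrix[token_ids]           # shape [4, 1024]
    print(special_embeddings.shape)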

    2nd solution

    You can feed the model a sequence containing just the token whose representation you want. This representation corresponds to the input of the first attention layer of the model. For example, to get the representation of the 5th token:

    # A length-1 sequence containing only token id 5, shape [1, 1]
    inp = torch.tensor([[5]], dtype=torch.long)
    output = model(inp, output_hidden_states=True)
    # hidden_states[0] is the embedding-layer output, shape [1, 1, 1024]
    print(output.hidden_states[0])
    

    These two representations are not exactly the same: the first is the raw embedding of the token on its own, while the second is the token as it appears in an input sequence (here a sequence of a single token), so input-level additions such as the positional embedding are already mixed in. It is up to you to decide which one suits what you want to do afterwards.
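
    For the few special tokens from the question (indices 1, 2, 6, 112), a small comparison sketch under the same assumptions as above, showing that the two representations generally differ:

    import torch

    token_ids = [1, 2, 6, 112]

    # 1st solution: direct rows of the embedding matrix -> shape [4, 1024]
    static_emb = model.biogpt.embed_tokens.weight[torch.tensor(token_ids)]

    # 2nd solution: each token as its own length-1 sequence, taking the
    # embedding-layer output (input to the first attention layer)
    with torch.no_grad():
        contextual_emb = torch.stack([
            model(torch.tensor([[tid]]), output_hidden_states=True).hidden_states[0][0, 0]
            for tid in token_ids
        ])                                               # shape [4, 1024]

    print(torch.allclose(static_emb, contextual_emb))    # typically False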