Here is what I did.
First, I downloaded the required files from the official website: config.json, merges.txt, pytorch_model.bin, tokenizer.json, tokenizer_config.json and vocab.json. I stored them in ./gpt2 at the root of the project.
Second, I loaded the model and predicted the next word from the input context with the following code:
import torch
from transformers import GPT2Model, GPT2Tokenizer
model = GPT2Model.from_pretrained('./gpt2')
gpt_tokenizer = GPT2Tokenizer.from_pretrained('./gpt2')
start_context = "The white man worked as a "
ids_text = gpt_tokenizer(start_context, return_tensors='pt')
output = model(**ids_text)
output = output.last_hidden_state[:, -1, :]
idx_next = torch.argmax(output, dim=-1, keepdim=True)
ids = idx_next.squeeze(0)
text = gpt_tokenizer.decode(ids.tolist())
print(text)
The printed text is always "age", even when I change start_context to something else, such as "I see a cat under".
Can someone explain why this happens and how to fix it? Thanks.
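The reason is that your output has shape [batch, hidden_size], which is [1, 768] for the base gpt2 model. GPT2Model returns hidden states, not scores over the vocabulary, so argmax gives you an index into the 768-dimensional hidden space rather than a token id, and decoding that index as a token is meaningless. Try GPT2LMHeadModel instead, which adds the language-modeling head and returns logits over the 50257-token vocabulary: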
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "2"  # machine-specific; set before torch is imported
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Model and tokenizer paths
model_path = "/mnt/sda/agent_mxz/models/gpt2"
tokenizer = GPT2Tokenizer.from_pretrained(model_path)
model = GPT2LMHeadModel.from_pretrained(model_path)
# Input texts
texts = ["Replace me by any text you'd like.", "Hello, this is", "Write a story for me."]
# Pad on the left so the last position of every row is a real token
tokenizer.padding_side = "left"
# GPT-2 has no PAD token, so reuse EOS as PAD
tokenizer.pad_token = tokenizer.eos_token
model.config.pad_token_id = model.config.eos_token_id
model.config.do_sample = False  # has no effect on the plain forward pass below
# Tokenize the inputs with padding
encoded_inputs = tokenizer(texts, return_tensors='pt', padding=True, truncation=True)
print(encoded_inputs)
# Forward pass; no gradients are needed for inference
with torch.no_grad():
    outputs = model(**encoded_inputs)
# outputs.logits has shape [batch, seq_len, vocab_size]; argmax over the
# vocabulary gives the predicted next token at every position
print(tokenizer.batch_decode(torch.argmax(outputs.logits, dim=-1)))
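After the encoded inputs, this prints ['. the with a means, want like.\n', ',HelloHelloHelloHelloHello Hello Hello hello hello', ',Write""Write write write the I.']. Each token in these strings is the model's prediction for the token that follows the corresponding input position (padding positions included), so the strings are not continuations of the input texts.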
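A minimal sketch of both points, assuming `torch` and `transformers` are installed. To run without downloading weights it builds tiny, randomly initialised models from an illustrative `GPT2Config` (the sizes are arbitrary assumptions, not the real gpt2 settings), so the generated ids are meaningless and only the shapes matter. `greedy_next_tokens` is a hypothetical helper, not a library function: it takes the argmax of the logits at the last position only, appends that token and repeats, which is what "predict the next word" means, whereas taking argmax over every position gives each position's own next-token guess.

```python
import torch
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Model

# Tiny, randomly initialised config (illustrative sizes only) so this runs offline.
# The real "gpt2" checkpoint uses n_embd=768 and vocab_size=50257.
config = GPT2Config(vocab_size=100, n_positions=64, n_embd=16, n_layer=2, n_head=2)
torch.manual_seed(0)
base = GPT2Model(config).eval()      # transformer body only, no LM head
lm = GPT2LMHeadModel(config).eval()  # transformer body + LM head

input_ids = torch.tensor([[5, 17, 42, 8]])  # hypothetical token ids, batch of 1
with torch.no_grad():
    hidden = base(input_ids).last_hidden_state  # [batch, seq_len, n_embd]
    logits = lm(input_ids).logits               # [batch, seq_len, vocab_size]
print(hidden.shape)  # torch.Size([1, 4, 16]): argmax here indexes hidden units
print(logits.shape)  # torch.Size([1, 4, 100]): argmax here indexes the vocabulary


def greedy_next_tokens(model, input_ids, attention_mask=None, max_new_tokens=5):
    """Hypothetical helper: greedy decoding from the LAST position's logits."""
    ids = input_ids
    mask = attention_mask if attention_mask is not None else torch.ones_like(ids)
    for _ in range(max_new_tokens):
        # Position ids that skip left padding, similar to what generate() does for GPT-2.
        position_ids = mask.long().cumsum(-1) - 1
        position_ids = position_ids.masked_fill(mask == 0, 1)
        with torch.no_grad():
            logits = model(input_ids=ids, attention_mask=mask,
                           position_ids=position_ids).logits
        # Only the last position predicts the token that follows the whole text.
        next_id = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # [batch, 1]
        ids = torch.cat([ids, next_id], dim=1)
        mask = torch.cat([mask, torch.ones_like(next_id)], dim=1)
    return ids


# Usage on a left-padded toy batch (0 marks padding in this example).
padded = torch.tensor([[0, 0, 5, 17], [3, 9, 11, 2]])
pad_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
out = greedy_next_tokens(lm, padded, pad_mask, max_new_tokens=3)
print(out.shape)  # torch.Size([2, 7])
```

With the real checkpoint you would instead load `GPT2LMHeadModel.from_pretrained('./gpt2')`, tokenize with left padding as in the answer, pass `encoded_inputs['input_ids']` and `encoded_inputs['attention_mask']`, and decode only the new columns with `tokenizer.batch_decode(out[:, encoded_inputs['input_ids'].shape[1]:])`. In practice `model.generate(**encoded_inputs, max_new_tokens=5, do_sample=False)` performs the same greedy decoding with more options; the hand-written loop is only there to make the "last position only" step explicit.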