I am using DistilBART for abstractive summarization. The generate() method is very straightforward to use, but it returns complete, finished summaries. What I want is to access the logits at each step, get the list of next-token candidates, and choose among them based on my own criteria. Once a token is chosen, generation should continue with the next one, and so on until the EOS token is produced.
I am aware that I can access the logits with model(**input).logits[:, -1, :], but here the input would be the whole (encoded) text, so what exactly would these logits correspond to? The first generated token? The last?
Thank you for your answers!
For future reference, here is how it can be done (note: this is specific to encoder-decoder models, like BART):
1. Initialization
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
# Load model
tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-1-1")
model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-1-1")
text = "..."
# Tokenize text
batch = tokenizer(text, return_tensors="pt")
2. Example 1: Summary generation with greedy decoding (no cache)
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token (</s>, BART's decoder start token)
# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    # Logits at the last decoder position score the next token to generate
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)
    # Take token with highest probability, keeping shape (1, 1)
    next_token = next_token_scores.argmax(dim=-1, keepdim=True)
    # Append token to generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if next_token.item() == tokenizer.eos_token_id:
        break
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
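On the original question of what logits[:, -1, :] refers to: the logits have shape (batch, decoder_length, vocab_size), and position i scores the token that follows decoder input i, so the last position scores the next token to generate. Below is a minimal sketch of listing the candidates with torch.topk instead of taking the argmax. The logits are made-up values for a 6-token toy vocabulary, and top_candidates is a hypothetical helper name, not part of transformers.

```python
import torch

def top_candidates(next_token_logits, k=3):
    """Return the k most likely next-token ids and their probabilities."""
    probs = next_token_logits.softmax(dim=-1)
    top_probs, top_ids = torch.topk(probs, k, dim=-1)
    return top_ids, top_probs

# Toy logits standing in for output.logits: (batch=1, decoder_len=2, vocab=6)
logits = torch.tensor([[[0.1, 0.2, 0.3, 0.4, 0.5, 0.6],
                        [2.0, 0.0, 1.0, 3.0, -1.0, 0.5]]])
# Scores for the token that comes after the last decoder input
next_token_logits = logits[:, -1, :]
cand_ids, cand_probs = top_candidates(next_token_logits, k=3)
print(cand_ids.tolist())  # [[3, 0, 2]], sorted by decreasing probability
```

With the real model, tokenizer.convert_ids_to_tokens(cand_ids[0].tolist()) should turn the ids into readable tokens, and any of them can be appended to generated_sequence in place of the argmax.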
3. Example 2: Summary generation with top-k, top-p sampling & temperature (no cache)
# Note: this import path is tied to older transformers releases; newer ones
# provide TopKLogitsWarper / TopPLogitsWarper for the same filtering
from transformers.generation_utils import top_k_top_p_filtering
temperature = 0.7
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token (</s>, BART's decoder start token)
# Generation loop
while True:
    with torch.no_grad():
        output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
    logits = output.logits[:, -1, :] / temperature  # apply temperature
    # Keep only the 4 best tokens, then the smallest set whose probability mass reaches 0.7
    filtered_logits = top_k_top_p_filtering(logits=logits, top_k=4, top_p=0.7)
    probabilities = filtered_logits.softmax(dim=-1)
    # Sample next token
    next_token = torch.multinomial(probabilities, 1)
    # Append token to generated sequence
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if next_token.item() == tokenizer.eos_token_id:
        break
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
(Other generation strategies are analogous: only the step that picks next_token from the logits changes.)
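If the top_k_top_p_filtering helper is not available in the installed transformers version, the filtering step can be approximated in plain PyTorch. The following is a simplified sketch, not the library's implementation (for instance, it has no min_tokens_to_keep option); filter_top_k_top_p is a hypothetical name, and the logits are toy values.

```python
import torch

def filter_top_k_top_p(logits, top_k=0, top_p=1.0):
    """Mask logits outside the top-k / nucleus (top-p) set with -inf.

    Expects logits of shape (batch, vocab_size); returns a new tensor.
    """
    logits = logits.clone()
    if top_k > 0:
        k = min(top_k, logits.size(-1))
        # Smallest logit that is still inside the top-k set, per row
        kth_value = torch.topk(logits, k, dim=-1).values[..., -1, None]
        logits[logits < kth_value] = float("-inf")
    if top_p < 1.0:
        sorted_logits, sorted_idx = torch.sort(logits, dim=-1, descending=True)
        cumulative = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
        # Drop tokens once the cumulative probability exceeds top_p,
        # shifted by one so the token that crosses the threshold is kept
        remove = cumulative > top_p
        remove[..., 1:] = remove[..., :-1].clone()
        remove[..., 0] = False
        # Map the mask back from sorted order to vocabulary order
        remove = remove.scatter(-1, sorted_idx, remove)
        logits[remove] = float("-inf")
    return logits

# Toy next-token logits for a 4-token vocabulary
toy_logits = torch.tensor([[0.0, 2.0, -1.0, 1.0]])
filtered = filter_top_k_top_p(toy_logits, top_k=3, top_p=0.7)
probabilities = filtered.softmax(dim=-1)
next_token = torch.multinomial(probabilities, 1)  # samples id 1 or id 3
```

In this toy case top-k drops token 2, and top-p then keeps only tokens 1 and 3, because token 1 alone carries less than 0.7 of the probability mass while tokens 1 and 3 together exceed it.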
4. Using cache
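Since the input to the encoder (i.e., the text to be summarized) is always the same, we can run the encoder once and reuse its output. Combined with the decoder's cached key/value states (past_key_values), only the newest token has to be fed to the decoder at each step, which greatly speeds up generation.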
generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token (</s>, BART's decoder start token)
input_ids = batch["input_ids"]
past_key_values = None
# First step: run encoder and decoder once, with no cache yet
with torch.no_grad():
    output = model(
        input_ids=input_ids,
        decoder_input_ids=generated_sequence,
        past_key_values=past_key_values
    )
# Keep the encoder output so the encoder is not run again; it is wrapped in a
# tuple because encoder_outputs is expected as (last_hidden_state, ...)
encoder_outputs = (output.encoder_last_hidden_state,)
# Generation loop
while True:
    # From here on, use cached attention
    past_key_values = output.past_key_values
    next_token_logits = output.logits[:, -1, :]
    next_token_scores = next_token_logits.softmax(dim=-1)
    next_token = next_token_scores.argmax(dim=-1, keepdim=True)  # greedy decoding
    generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
    # Stop if EOS token generated
    if next_token.item() == tokenizer.eos_token_id:
        break
    with torch.no_grad():
        output = model(
            decoder_input_ids=next_token,  # only the newest token; the rest is in the cache
            past_key_values=past_key_values,
            encoder_outputs=encoder_outputs
        )
summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
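The loops above differ only in how the next token is chosen, and none of them stops if EOS is never produced. A possible way to factor this out is a small driver that takes the selection rule as a function and caps the number of steps. This is only a sketch: decode_with_chooser, toy_step and greedy_choice are hypothetical names, and toy_step is a stand-in for the model call so that the example runs without downloading weights.

```python
import torch

def decode_with_chooser(step_fn, choose_fn, start_id, eos_id, max_new_tokens=50):
    """Token-by-token decoding with a pluggable selection rule.

    step_fn(sequence) -> next-token logits of shape (1, vocab_size)
    choose_fn(logits) -> chosen token id as a tensor of shape (1, 1)
    """
    sequence = torch.tensor([[start_id]])
    for _ in range(max_new_tokens):  # hard cap so the loop always ends
        next_token = choose_fn(step_fn(sequence))
        sequence = torch.cat((sequence, next_token), dim=1)
        if next_token.item() == eos_id:
            break
    return sequence

def toy_step(sequence, vocab_size=5):
    """Toy stand-in for the model: prefers (last token + 1) modulo vocab_size."""
    logits = torch.zeros(1, vocab_size)
    logits[0, (sequence[0, -1].item() + 1) % vocab_size] = 1.0
    return logits

def greedy_choice(logits):
    """Custom criterion goes here; this one simply takes the argmax."""
    return logits.argmax(dim=-1, keepdim=True)

result = decode_with_chooser(toy_step, greedy_choice, start_id=0, eos_id=3)
print(result.tolist())  # [[0, 1, 2, 3]]
```

With the real model, step_fn would wrap the model(...) call and return output.logits[:, -1, :], and choose_fn is where any custom selection criterion would go.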