huggingface-transformers, summarization, huggingface

Summarization with Huggingface: How to generate one word at a time?


I am using a DistilBART model for abstractive summarization. The generate() method is very straightforward to use, but it returns complete, finished summaries. What I want is, at each step, to access the logits, get the list of next-word candidates, and choose one based on my own criteria. Once chosen, continue with the next word and so on until the EOS token is produced.

I am aware that I can access the logits by doing model(**input).logits[:, -1, :], but here the input would be the whole (encoded) text, so what exactly would these logits correspond to? The first generated token? The last?

Thank you for your answers!


Solution

  • For future reference, here is how it can be done (note: this is specific to encoder-decoder models, like BART):

    1. Initialization

    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    
    # Load model
    tokenizer = AutoTokenizer.from_pretrained("sshleifer/distilbart-xsum-1-1")
    model = AutoModelForSeq2SeqLM.from_pretrained("sshleifer/distilbart-xsum-1-1")
    
    text = "..."
    
    # Tokenize text
    batch = tokenizer(text, return_tensors="pt")
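
    As a quick sanity check (not part of the original answer), you can inspect the special tokens to see why the decoder sequence in the examples below is seeded with sep_token_id: for BART tokenizers, sep and eos are the same "</s>" token, and this should also match the model's configured decoder start token.

    # Optional sanity check on the special tokens used below
    print(tokenizer.sep_token_id, tokenizer.eos_token_id)  # expected to be the same id for BART
    print(model.config.decoder_start_token_id)             # the start token generate() itself would use
    print(batch["input_ids"].shape)                        # (1, sequence_length)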
    

    2. Example 1: Summary generation with greedy decoding (no cache)

    generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial decoder token (for BART, sep and eos are both "</s>")
    
    # Generation loop
    while True:
        with torch.no_grad():
            output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
        next_token_logits = output.logits[:, -1, :]
        next_token_scores = next_token_logits.softmax(dim=-1)
    
        # Take token with highest probability
        next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)
    
        # Append token to generated sequence
        generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
        # Stop if EOS token generated
        if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
            break
    
    summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
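
    As a check (not from the original answer), the manual greedy loop can be compared against the built-in generate() with sampling and beam search disabled. The outputs should be very close, though not necessarily identical, since generate() also applies defaults from the model config such as min_length; max_length=60 below is an arbitrary choice.

    # Reference: built-in greedy decoding
    with torch.no_grad():
        reference_ids = model.generate(batch["input_ids"], num_beams=1, do_sample=False, max_length=60)
    print(tokenizer.batch_decode(reference_ids, skip_special_tokens=True)[0])
    print(summary[0])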
    

    3. Example 2: Summary generation with top-k, top-p sampling & temperature (no cache)

    # Note: recent transformers releases have deprecated/removed top_k_top_p_filtering;
    # a warper-based alternative is sketched after this example.
    from transformers.generation_utils import top_k_top_p_filtering
    
    temperature = 0.7
    generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token
    
    # Generation loop
    while True:
        with torch.no_grad():
            output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
        logits = output.logits[:, -1, :] / temperature  # apply temperature
        filtered_logits = top_k_top_p_filtering(logits=logits, top_k=4, top_p=0.7)
        probabilities = filtered_logits.softmax(dim=-1)
    
        # Sample next token
        next_token = torch.multinomial(probabilities, 1)
    
        # Append token to generated sequence
        generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
        # Stop if EOS token generated
        if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
            break
    
    summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
    

    (Other generation strategies would be analogous.)
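
    Note that in recent transformers releases top_k_top_p_filtering has been deprecated (and eventually removed). A roughly equivalent loop can be written with the logits warpers that generate() uses internally; this is a sketch under that assumption, reusing the same top_k, top_p and temperature values as above:

    from transformers import TopKLogitsWarper, TopPLogitsWarper

    temperature = 0.7
    top_k_warper = TopKLogitsWarper(top_k=4)
    top_p_warper = TopPLogitsWarper(top_p=0.7)

    generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token

    while True:
        with torch.no_grad():
            output = model(input_ids=batch["input_ids"], decoder_input_ids=generated_sequence)
        logits = output.logits[:, -1, :] / temperature  # apply temperature
        # The warpers mask filtered-out positions with -inf, like top_k_top_p_filtering did
        filtered_logits = top_p_warper(generated_sequence, top_k_warper(generated_sequence, logits))
        probabilities = filtered_logits.softmax(dim=-1)
        next_token = torch.multinomial(probabilities, 1)
        generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
        if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
            break

    summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)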


    4. Using cache

    Since the input to the encoder (i.e., the text to be summarized) is always the same, we can reuse the encoder's hidden states and the decoder's past key/values instead of recomputing them, which greatly speeds up generation: after the first forward pass, each step only feeds the newest token to the decoder (a rough timing comparison follows the code below).

    from transformers.modeling_outputs import BaseModelOutput  # used below to wrap the cached encoder states

    generated_sequence = torch.tensor([[tokenizer.sep_token_id]])  # initial token
    input_ids = batch["input_ids"]
    past_key_values = None
    
    with torch.no_grad():
        output = model(
            input_ids=input_ids,
            decoder_input_ids=generated_sequence,
            past_key_values=past_key_values
        )
        
    # Cache the encoder output; the model expects a ModelOutput (or tuple), not a bare tensor
    encoder_outputs = BaseModelOutput(last_hidden_state=output.encoder_last_hidden_state)
    
    # Generation loop
    while True:
        # From here on, use cached attention
        past_key_values = output.past_key_values
        next_token_logits = output.logits[:, -1, :]
        next_token_scores = next_token_logits.softmax(dim=-1)
        next_token = next_token_scores.argmax().unsqueeze(0).unsqueeze(0)  # greedy decoding
        generated_sequence = torch.cat((generated_sequence, next_token), dim=1)
        # Stop if EOS token generated
        if generated_sequence.squeeze()[-1] == tokenizer.eos_token_id:
            break
        with torch.no_grad():
            output = model(
                decoder_input_ids=generated_sequence[:, -1:],  # only the newest token; earlier ones are in the cache
                past_key_values=past_key_values,
                encoder_outputs=encoder_outputs
            )
    
    summary = tokenizer.batch_decode(generated_sequence, skip_special_tokens=True)
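
    To see the effect of the cache in isolation, the built-in generate() can be timed with the cache switched on and off (a rough sketch, not part of the original answer; absolute numbers depend on your hardware, and max_length=60 is an arbitrary choice):

    import time

    for use_cache in (True, False):
        start = time.perf_counter()
        with torch.no_grad():
            model.generate(batch["input_ids"], num_beams=1, do_sample=False, use_cache=use_cache, max_length=60)
        print(f"use_cache={use_cache}: {time.perf_counter() - start:.2f}s")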