python, huggingface-transformers, huggingface, huggingface-datasets, huggingface-trainer

How to use the Hugging Face (HF) Trainer with a custom collate function?


I have a custom dataset with custom table entries and wanted to handle it with a custom collate function. But it didn't work when I passed a collate function I wrote (one that DOES work on an individual dataloader, e.g., see How does one create a pytorch data loader with a custom hugging face data set without having errors? or How does one create a pytoch data loader using an interleaved hugging face dataset?). It just doesn't work with the HF Trainer.
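
For reference, here is roughly what I mean by the collate function working with a plain PyTorch dataloader (a minimal sketch using the same dataset and column names as the code below):

from datasets import load_dataset
from torch.utils.data import DataLoader
from transformers import GPT2Tokenizer

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

def collate_tokenize(data):
    # data is a list of raw dataset rows (dicts with the string columns)
    text_batch = [f'informal statement {example["generated informal statement"]} formal statement {example["formal statement"]}' for example in data]
    return tokenizer(text_batch, padding='longest', max_length=128, truncation=True, return_tensors='pt')

train_dataset = load_dataset('brando/debug0_af', 'debug0_af', split='train').with_format(type='torch')
loader = DataLoader(train_dataset, batch_size=4, collate_fn=collate_tokenize)
batch = next(iter(loader))
print(batch['input_ids'].shape)  # e.g. torch.Size([4, longest_sequence_in_the_batch])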

Code

from pathlib import Path
# token = open(Path('~/data/hf_token.txt').expanduser()).read().strip()
token = None
batch_size = 8

# -- AF now
from datasets import load_dataset
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
if tokenizer.pad_token_id is None:
  tokenizer.pad_token = tokenizer.eos_token
model = GPT2LMHeadModel.from_pretrained("gpt2")
device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# -- Get batch from dataset
from datasets import load_dataset
# path, name = 'brando/debug1_af', 'debug1_af'
path, name = 'brando/debug0_af', 'debug0_af'
# train_dataset = load_dataset(path, name, streaming=True, split="train", token=token).with_format(type="torch")
# eval_dataset = load_dataset(path, name, streaming=True, split="test", token=token).with_format(type="torch")
# batch = dataset.take(1)
# column_names = next(iterbatch).keys()
# print(f'{column_names=}')

# -- Compute max steps (I think we should try to do this for real experiments such that the number of tokens is the same in all training runs for fair experiments, todo: ask Sudharsan or online, for now just make streaming=False)
train_dataset = load_dataset(path, name, streaming=False, split="train", token=token).with_format(type="torch")  # hack to get dataset size
eval_dataset = load_dataset(path, name, streaming=False, split="test", token=token).with_format(type="torch") # hack to get dataset size
print(f'{len(train_dataset)=}')
print(f'{len(eval_dataset)=}')
per_device_train_batch_size = batch_size
num_epochs = 1
max_steps = (len(train_dataset) // per_device_train_batch_size) * num_epochs
print(f'{max_steps=}')    

# -- Get trainer
def collate_tokenize(data):
    text_batch = [f'informal statement {example["generated informal statement"]} formal statement {example["formal statement"]}' for example in data]
    tokenized = tokenizer(text_batch, padding='longest', max_length=128, truncation=True, return_tensors='pt')
    return tokenized

from transformers import Trainer, TrainingArguments
training_args = TrainingArguments(
    output_dir=Path('./results').expanduser(),          # output directory
    max_steps=max_steps,             # max_steps
    per_device_train_batch_size=per_device_train_batch_size,   # batch size per device during training
    per_device_eval_batch_size=batch_size,    # batch size for evaluation
    warmup_steps=500,                # number of warmup steps for learning rate scheduler
    weight_decay=0.01,               # strength of weight decay
    logging_dir=Path('./logs').expanduser(),            # directory for storing logs
    logging_steps=10,
    report_to='none',
)
trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_dataset,         # training dataset
    eval_dataset=eval_dataset,             # evaluation dataset
    data_collator = collate_tokenize,
)
trainer.train()
print('Done!\a')

error:

len(train_dataset)=14
len(eval_dataset)=13
max_steps=1
/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
  warnings.warn(
---------------------------------------------------------------------------
IndexError                                Traceback (most recent call last)
<ipython-input-2-4403554fc52d> in <cell line: 63>()
     61     data_collator = collate_tokenize,
     62 )
---> 63 trainer.train()
     64 print('Done!\a')

11 frames
/usr/local/lib/python3.10/dist-packages/datasets/formatting/formatting.py in _check_valid_index_key(key, size)
    524     if isinstance(key, int):
    525         if (key < 0 and key + size < 0) or (key >= size):
--> 526             raise IndexError(f"Invalid key: {key} is out of bounds for size {size}")
    527         return
    528     elif isinstance(key, slice):

IndexError: Invalid key: 12 is out of bounds for size 0

Why does this happen, and how do I fix it?


Solution

  • There are a couple of issues with your code that interfere with the HF Trainer class. Here are the changes I made: the collate function now returns a "labels" key (a clone of input_ids), which the model needs in order to compute a loss; remove_unused_columns=False is set in TrainingArguments, because by default the Trainer drops the raw string columns (they don't match the model's forward signature), which is what causes IndexError: Invalid key: 12 is out of bounds for size 0; and the tokenizer is passed into the collator explicitly via a lambda.

    There are some additional suggestions here as well.
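
    To see the "labels" point concretely: GPT2LMHeadModel only returns a loss when labels are passed, and that loss is what the Trainer optimizes. A minimal sketch (the example string is made up):

    from transformers import GPT2LMHeadModel, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    model = GPT2LMHeadModel.from_pretrained("gpt2")

    enc = tok(["informal statement x formal statement y"], return_tensors="pt")
    out_no_labels = model(**enc)                              # out_no_labels.loss is None
    out_with_labels = model(**enc, labels=enc["input_ids"])   # out_with_labels.loss is a scalar tensor
    print(out_no_labels.loss, out_with_labels.loss)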

    If you run into other issues, you can always set the logging info like this:

    import transformers
    transformers.logging.set_verbosity_info()
    

    Here's the working code:

    from pathlib import Path
    from datasets import load_dataset
    import torch
    from transformers import GPT2LMHeadModel, PreTrainedTokenizer, AutoTokenizer, Trainer, TrainingArguments
    
    # Load model and tokenizer
    model = GPT2LMHeadModel.from_pretrained("gpt2")
    device = torch.device(f"cuda:{0}" if torch.cuda.is_available() else "cpu")
    model = model.to(device)
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    
    # Ensure padding token is set
    tokenizer.pad_token = tokenizer.eos_token
    if tokenizer.pad_token_id is None:
        raise ValueError("Padding token is not set.")
    
    # Load datasets
    path, name = 'brando/debug0_af', 'debug0_af'
    train_dataset = load_dataset(path, name, streaming=False, split="train").with_format(type="torch")
    eval_dataset = load_dataset(path, name, streaming=False, split="test").with_format(type="torch")
    
    # Set max steps (hard-coded here for this small debug dataset)
    batch_size = 3
    print(f'{len(train_dataset)=}')
    print(f'{len(eval_dataset)=}')
    per_device_train_batch_size = batch_size
    num_epochs = 1
    max_steps = 8
    print(f'{max_steps=}')
    
    # Define custom collate function
    from typing import List, Dict
    from transformers import PreTrainedTokenizer
    
    def custom_collate_fn(data: List[Dict[str, str]], tokenizer: PreTrainedTokenizer) -> Dict[str, torch.Tensor]:
        # Ensure tokenizer has a padding token
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
    
        # Extract and concatenate informal and formal statements
        sequences = []
        for idx, example in enumerate(data):
            # Handle null values
            informal = example.get("generated informal statement", "") or ""
            formal = example.get("formal statement", "") or ""
    
            # Skip if both are empty
            if not informal and not formal:
                continue
    
            sequences.append(f'informal statement {informal} formal statement {formal}')
    
        # Tokenize the sequences
        tokenized_data = tokenizer(sequences, padding='longest', truncation=True, return_tensors='pt')
        tokenized_data["labels"] = tokenized_data["input_ids"].clone()
    
        return tokenized_data
    
    # Training arguments and trainer instantiation
    training_args = TrainingArguments(
        output_dir=Path('./results').expanduser(),
        max_steps=max_steps,
        per_device_train_batch_size=per_device_train_batch_size,
        per_device_eval_batch_size=batch_size,
        warmup_steps=500,
        weight_decay=0.01,
        logging_dir=Path('./logs').expanduser(),
        logging_steps=10,
        remove_unused_columns=False,
        report_to='none',
    )
    
    
    # Sanity-check the collate function on a few raw samples before handing it to the Trainer
    sample_data = [train_dataset[i] for i in range(batch_size)]
    processed_data = custom_collate_fn(sample_data, tokenizer=tokenizer)
    
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        data_collator=lambda data: custom_collate_fn(data, tokenizer=tokenizer)
    )
    
    trainer.train()
    print('Done!\a')
    
    

    And a colab with some checks of the results.
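
    For example, continuing from the code above, a quick sanity check of what the collator returns (a minimal sketch; the exact shapes depend on your data):

    batch = custom_collate_fn([train_dataset[i] for i in range(3)], tokenizer=tokenizer)
    print(batch.keys())              # should include 'input_ids', 'attention_mask' and 'labels'
    print(batch['input_ids'].shape)  # (3, longest_sequence_in_the_batch)
    assert torch.equal(batch['labels'], batch['input_ids'])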