Tags: python, pytorch, nlp, batch-processing, dataloader

Dataloader/sampler/collator to create batches based on the sample contents (sequence length)


I am converting someone else's code into a neater torch-y pipeline, using datasets and dataloaders, collate functions and samplers. While I have done such work before, I am not sure how to tackle the following problem.

The dataset contains sentences as samples. Every sample therefore has a number of words (or tokens), which we can get by naively splitting the sample on whitespace (sample.split()). Such a dummy dataset can look like this:

from random import randint

from torch.utils.data import Dataset


class DummyDataset(Dataset):
    """Toy corpus of 128 sentences.

    Each sentence is the token ``"hello "`` repeated a random number of times
    between 64 and 176 (both inclusive), so ``sentence.split()`` yields that
    many tokens.
    """

    def __init__(self):
        # One randint draw per sample, in order -- the random stream is
        # consumed exactly as an explicit append loop would consume it.
        self.data = ["hello " * randint(64, 176) for _ in range(128)]

    def __len__(self):
        """Return the number of sentences in the corpus."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the raw sentence string stored at position ``idx``."""
        sentence = self.data[idx]
        return sentence

Now I want to be able to load data so that the maximum number of tokens in a batch is no more than 250. That implies that the batch size can differ between iterations: one batch may contain two samples that have no more than 250 tokens in total (for instance 127 + 77), and another may contain three (66 + 66 + 66). The core functionality for this is rather straightforward; a full example is below. It is not optimized (e.g. by sorting on length), but that's okay for this example.

The question is: how can I integrate this into the PyTorch ecosystem? "Batch size" almost always means the number of samples (as in DataLoader's batch_size argument). So where should I plug this in, or what should I subclass, to make this work like a regular DataLoader?

from random import randint

from torch.utils.data import Dataset

class DummyDataset(Dataset):
    """Toy corpus of 128 sentences.

    Each sentence is the token ``"hello "`` repeated a random number of times
    between 64 and 176 (both inclusive), so ``sentence.split()`` yields that
    many tokens.
    """

    def __init__(self):
        # One randint draw per sample, in order -- the random stream is
        # consumed exactly as an explicit append loop would consume it.
        self.data = ["hello " * randint(64, 176) for _ in range(128)]

    def __len__(self):
        """Return the number of sentences in the corpus."""
        return len(self.data)

    def __getitem__(self, idx: int):
        """Return the raw sentence string stored at position ``idx``."""
        sentence = self.data[idx]
        return sentence


if __name__ == '__main__':
    dataset = DummyDataset()

    def get_batch(max_tokens: int = 250):
        """Yield lists of samples whose total token count is at most ``max_tokens``.

        Samples are consumed in dataset order and tokens are counted with a
        naive whitespace split. A batch is closed as soon as the next sample
        would push its running total over the budget, so batch sizes vary.

        Raises:
            ValueError: if a single sample alone has more than ``max_tokens``
                tokens. Such a sample can never fit in any batch (without this
                check the loop would spin forever on it).
        """
        batch = []
        total_batch_len = 0
        # Plain index iteration instead of popping from the front of a list,
        # which costs O(n) per pop.
        for idx in range(len(dataset)):
            sample = dataset[idx]
            sample_len = len(sample.split())
            if sample_len > max_tokens:
                raise ValueError(
                    f"Sample {idx} has {sample_len} tokens, which exceeds "
                    f"max_tokens={max_tokens}"
                )
            if total_batch_len + sample_len > max_tokens:
                # The batch is non-empty here: an empty batch always has room
                # for a sample that passed the size check above.
                yield batch
                batch = []
                total_batch_len = 0
            batch.append(sample)
            total_batch_len += sample_len

        # Flush the remainder, but never emit an empty batch (e.g. empty dataset).
        if batch:
            yield batch

    # Sanity check that we indeed get all items from the dataset
    num_samples = 0
    num_batches = 0
    for b in get_batch():
        num_samples += len(b)
        num_batches += 1

    print(f"Created {num_batches} batches")
    assert num_samples == len(dataset)

Maybe torchtext's Iterator and its batch_size_fn could help, but I have no experience with it. Where should I add it? Is it a dataloader itself, or should I still wrap a dataloader around it?


Solution

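  • After reading some source code, it turns out that a DataLoader's batch_sampler can be any iterable that yields lists of dataset indices. It must be re-iterable rather than a one-shot iterator, because the DataLoader iterates over it again every epoch. Give it a __len__ as well if you want len(dataloader) to work. So the following works as expected.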

    from random import randint
    
    from torch.utils.data import Dataset
    from torch.utils.data.dataloader import DataLoader
    
    
    class DummyDataset(Dataset):
        """Toy corpus of 128 sentences.

        Each sentence is the token ``"hello "`` repeated a random number of
        times between 64 and 176 (both inclusive), so ``sentence.split()``
        yields that many tokens.
        """

        def __init__(self):
            # One randint draw per sample, in order -- the random stream is
            # consumed exactly as an explicit append loop would consume it.
            self.data = ["hello " * randint(64, 176) for _ in range(128)]

        def __len__(self):
            """Return the number of sentences in the corpus."""
            return len(self.data)

        def __getitem__(self, idx: int):
            """Return the raw sentence string stored at position ``idx``."""
            sentence = self.data[idx]
            return sentence
    
    
    class TokenBatchSampler:
        """Batch sampler that groups dataset indices under a token budget.

        Pass an instance as ``DataLoader(dataset, batch_sampler=...)``. Each
        iteration yields a list of indices whose samples together contain at
        most ``max_tokens`` whitespace-separated tokens, so the number of
        samples per batch varies.

        Batches are computed once, in dataset order, at construction time. The
        same batches are replayed on every epoch.

        Args:
            max_tokens: upper bound on the summed token count of one batch.
            data_source: sized, indexable collection of strings to batch. When
                omitted, falls back to the module-level ``dataset`` so existing
                ``TokenBatchSampler()`` calls keep working.

        Raises:
            ValueError: if any single sample has more than ``max_tokens``
                tokens, since it could never fit in a batch.
        """

        def __init__(self, max_tokens: int = 250, data_source=None):
            self.max_tokens = max_tokens
            # Backward compatibility: the original implementation read the
            # global ``dataset`` implicitly; only touch it when nothing is passed.
            self.data_source = dataset if data_source is None else data_source
            self.batches = []
            self._prepare_dataset()

        def __len__(self) -> int:
            """Return the number of batches per epoch (used by ``len(dataloader)``)."""
            return len(self.batches)

        def __iter__(self):
            """Return a fresh iterator over the precomputed index batches."""
            return iter(self.batches)

        def _prepare_dataset(self):
            """Partition the indices of ``self.data_source`` into token-bounded batches."""
            batches = []
            batch_idxs = []
            total_batch_len = 0
            # Index iteration instead of list.pop(0), which is O(n) per pop.
            for sample_idx in range(len(self.data_source)):
                sample_len = len(self.data_source[sample_idx].split())
                if sample_len > self.max_tokens:
                    # Without this check the batching loop never makes progress.
                    raise ValueError(
                        f"Sample {sample_idx} has {sample_len} tokens, which "
                        f"exceeds max_tokens={self.max_tokens}"
                    )
                if total_batch_len + sample_len > self.max_tokens:
                    # The batch is non-empty here: an empty batch always fits
                    # a sample that passed the check above.
                    batches.append(batch_idxs)
                    batch_idxs = []
                    total_batch_len = 0
                batch_idxs.append(sample_idx)
                total_batch_len += sample_len

            # Keep the trailing partial batch, but never an empty one: an empty
            # index list would make the DataLoader collate an empty batch.
            if batch_idxs:
                batches.append(batch_idxs)

            self.batches = batches
    
    
    if __name__ == "__main__":
        dataset = DummyDataset()
        # The sampler is built after ``dataset`` exists because it reads it.
        dataloader = DataLoader(dataset, batch_sampler=TokenBatchSampler())

        # Sanity check: every epoch must visit each dataset item exactly once.
        for epoch in range(3):
            batch_sizes = [len(batch) for batch in dataloader]
            print(f"Created {len(batch_sizes)} batches in epoch {epoch}")
            assert sum(batch_sizes) == len(dataset)

        print(f"DataLoader length {len(dataloader)}")