Tags: python, pytorch, pytorch-dataloader

Iteration not working when using a sampler in PyTorch's DataLoader


I have this sampler class that lets me sample my data with different batch sizes.

class VaribleBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]
        
    def __iter__(self):
        return self
       
    def __next__(self):
        if self.start_idx >= self.dataset_len:
            raise StopIteration()
 
        batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.long)
        self.start_idx += (self.end_idx - self.start_idx)
        self.batch_idx += 1

        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len    

        return batch_indices

But I can't manage to iterate over it in an epoch loop: it only works for the first epoch.

batch_sizes = [4, 10, 7, ..., 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VaribleBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

for epoch in np.arange(1, max_epoch):
    model.train()
    for x_batch, y_batch in dataloader_train:
        ...

Solution

  • You raise the StopIteration exception but forget to reset your indices for the next epoch, so iteration automatically stops after one epoch. The DataLoader calls iter() on the sampler at the start of every epoch; since your __iter__ returns self without resetting start_idx, the first __next__ call of the second epoch already sees start_idx >= dataset_len and raises StopIteration immediately.

    I have extended your code snippets into a working example (with the typo in the class name fixed) that should behave the way you intended.

    import torch
    import numpy as np
    from torch.utils.data import Sampler
    from torch.utils.data import DataLoader, TensorDataset
    
    
    
    class VariableBatchSampler(Sampler):
        def __init__(self, dataset_len: int, batch_sizes: list):
            self.dataset_len = dataset_len
            self.batch_sizes = batch_sizes
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
    
        def __iter__(self):
            return self
    
        def __next__(self):
            if self.start_idx >= self.dataset_len:
                # All samples have been consumed: reset the state so the
                # next epoch starts from the beginning, then signal the
                # end of this epoch.
                self.batch_idx = 0
                self.start_idx = 0
                self.end_idx = self.batch_sizes[self.batch_idx]
                raise StopIteration
    
            batch_indices = list(range(self.start_idx, self.end_idx))
            self.start_idx = self.end_idx
            self.batch_idx += 1
    
            try:
                self.end_idx += self.batch_sizes[self.batch_idx]
            except IndexError:
                self.end_idx = self.dataset_len
    
            return batch_indices
    
    
    x_train = torch.randn(23)
    y_train = torch.randint(0, 2, (23,))
    
    batch_sizes = [4, 10, 7, 2]
    train_dataset = TensorDataset(x_train, y_train)
    sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
    dataloader_train = DataLoader(train_dataset, sampler=sampler)
    
    max_epoch = 4
    for epoch in np.arange(1, max_epoch):
        print("Epoch: ", epoch)
        for x_batch, y_batch in dataloader_train:
            print(x_batch.shape)
    

    This outputs:

    Epoch: 1
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
    Epoch: 2
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
    Epoch: 3
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
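
    Note that every batch comes back with a leading dimension of 1 (torch.Size([1, 4]) instead of torch.Size([4])). That happens because the sampler is passed via the sampler argument, so the DataLoader treats each yielded list of indices as a single "sample" and then collates it with its default batch_size=1. If you want the batches without that extra dimension, you can pass the sampler as batch_sampler instead. A minimal sketch, reusing train_dataset and the fixed sampler from above (batch_sampler is mutually exclusive with batch_size, shuffle, sampler, and drop_last, so none of those may be set):

    # Each list of indices yielded by the sampler now forms one complete batch.
    dataloader_train = DataLoader(train_dataset, batch_sampler=sampler)

    for x_batch, y_batch in dataloader_train:
        print(x_batch.shape)  # e.g. torch.Size([4]) instead of torch.Size([1, 4])

    Alternatively, you can avoid the manual reset altogether by writing __iter__ as a generator: the DataLoader calls iter() on the sampler at the start of every epoch, so each epoch gets a fresh generator with fresh state. A sketch of that variant, keeping the same remainder behavior as the original try/except:

    from torch.utils.data import Sampler


    class VariableBatchSampler(Sampler):
        def __init__(self, dataset_len: int, batch_sizes: list):
            self.dataset_len = dataset_len
            self.batch_sizes = batch_sizes

        def __iter__(self):
            # A fresh generator (and therefore fresh state) is created each
            # time the DataLoader starts an epoch, so nothing is reset by hand.
            start = 0
            for size in self.batch_sizes:
                if start >= self.dataset_len:
                    return
                end = min(start + size, self.dataset_len)
                yield list(range(start, end))
                start = end
            if start < self.dataset_len:
                # Like the original IndexError fallback: one last batch
                # covering whatever remains of the dataset.
                yield list(range(start, self.dataset_len))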