python, pytorch, pytorch-dataloader, batchsize

Change batch size using a list with PyTorch's DataLoader


During the training of my neural network model, I use PyTorch's DataLoader to speed up training. But instead of using a fixed batch size for every parameter update, I have a list of different batch sizes that I want the data loader to use.

Example

from torch.utils.data import TensorDataset, DataLoader

train_dataset = TensorDataset(x_train, y_train)  # x_train.shape == (8400, 4)
dataloader_train = DataLoader(train_dataset, batch_size=64)  # fixed batch size of 64

What I want is a data loader that can use a list of batch sizes that is dynamic (not fixed):

list_batch_size = [30, 60, 110, ..., 231] # with this list's sum being equal to x_train.shape[0] (8400) 

Solution

  • You can use a custom sampler (or batch sampler) for this.

    Here's a quick proof-of-concept for a sampler that takes custom batch sizes as an argument and returns the corresponding batch indices:

    import torch
    from torch.utils.data import Sampler

    class VariableBatchSampler(Sampler):
        def __init__(self, dataset_len: int, batch_sizes: list):
            self.dataset_len = dataset_len
            self.batch_sizes = batch_sizes
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]

        def __iter__(self):
            # Reset the cursor so the sampler can be iterated again for a new epoch.
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            return self

        def __next__(self):
            if self.start_idx >= self.dataset_len:
                raise StopIteration()

            # Indices of the current batch: [start_idx, end_idx).
            batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.int)
            self.start_idx = self.end_idx
            self.batch_idx += 1

            try:
                # Extend the window by the next requested batch size.
                self.end_idx += self.batch_sizes[self.batch_idx]
            except IndexError:
                # No sizes left: the final batch covers the remaining samples.
                self.end_idx = self.dataset_len

            return batch_indices
    

    You can instantiate the sampler and pass it as the sampler argument when instantiating the DataLoader, e.g.:

    sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=[10, 20, 30, 40])
    data_loader = DataLoader(train_dataset, sampler=sampler)
    

    Note that each element in the data_loader iterable will then contain an extra dimension for the batch (since the default batch_size in DataLoader is 1); you can either use squeeze(dim=0) to get rid of the extra dim, or, better, pass the sampler as the batch_sampler argument:

    data_loader = DataLoader(train_dataset, batch_sampler=sampler)
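
    With the batch_sampler route, a quick sanity check along these lines (the batch sizes below are made up for illustration; they sum to 8400 to match x_train from the question) would show each yielded batch matching the requested size:

    sampler = VariableBatchSampler(dataset_len=len(train_dataset), batch_sizes=[30, 60, 110, 8200])
    data_loader = DataLoader(train_dataset, batch_sampler=sampler)

    for x_batch, y_batch in data_loader:
        # Prints torch.Size([30, 4]), torch.Size([60, 4]), torch.Size([110, 4]), torch.Size([8200, 4])
        print(x_batch.shape)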