I have this sampler class that lets me sample my data with a different batch size for each batch.
class VaribleBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            raise StopIteration()
        batch_indices = torch.arange(self.start_idx, self.end_idx, dtype=torch.long)
        self.start_idx += (self.end_idx - self.start_idx)
        self.batch_idx += 1
        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            self.end_idx = self.dataset_len
        return batch_indices
But I can't manage to iterate over it in an epoch loop; it only works for one epoch.
batch_sizes = [4, 10, 7, ..., 2]
train_dataset = TensorDataset(x_train, y_train)
sampler = VaribleBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

for epoch in np.arange(1, max_epoch):
    model.train()
    for x_batch, y_batch in dataloader_train:
        ...
You raise the StopIteration exception, but forget to reset your indices for the next epoch! Because your __iter__ returns self, the DataLoader reuses the same exhausted sampler object in the second epoch: start_idx is still past the end of the dataset, so iteration stops immediately after the first epoch.
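The minimal fix is to reset all three indices right before raising StopIteration, so the sampler starts from the beginning the next time the DataLoader iterates over it:

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            # Reset state so the next epoch starts at the first batch again.
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration
        ...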
Below, I have extended your code snippets into a complete working example (with the class-name typo fixed) that behaves the way you intended.
import torch
import numpy as np
from torch.utils.data import DataLoader, Sampler, TensorDataset

class VariableBatchSampler(Sampler):
    def __init__(self, dataset_len: int, batch_sizes: list):
        self.dataset_len = dataset_len
        self.batch_sizes = batch_sizes
        self.batch_idx = 0
        self.start_idx = 0
        self.end_idx = self.batch_sizes[self.batch_idx]

    def __iter__(self):
        return self

    def __next__(self):
        if self.start_idx >= self.dataset_len:
            # Dataset exhausted: reset all indices so the sampler can be
            # iterated again next epoch, then signal the end of this one.
            self.batch_idx = 0
            self.start_idx = 0
            self.end_idx = self.batch_sizes[self.batch_idx]
            raise StopIteration
        batch_indices = list(range(self.start_idx, self.end_idx))
        self.start_idx = self.end_idx
        self.batch_idx += 1
        try:
            self.end_idx += self.batch_sizes[self.batch_idx]
        except IndexError:
            # Ran out of batch sizes: the final batch takes whatever
            # remains of the dataset.
            self.end_idx = self.dataset_len
        return batch_indices
x_train = torch.randn(23)
y_train = torch.randint(0, 2, (23,))
batch_sizes = [4, 10, 7, 2]

train_dataset = TensorDataset(x_train, y_train)
sampler = VariableBatchSampler(dataset_len=len(x_train), batch_sizes=batch_sizes)
dataloader_train = DataLoader(train_dataset, sampler=sampler)

max_epoch = 4
for epoch in np.arange(1, max_epoch):
    print("Epoch: ", epoch)
    for x_batch, y_batch in dataloader_train:
        print(x_batch.shape)
This outputs:

    Epoch: 1
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
    Epoch: 2
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
    Epoch: 3
    torch.Size([1, 4])
    torch.Size([1, 10])
    torch.Size([1, 7])
    torch.Size([1, 2])
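Two side notes. First, each batch carries an extra leading dimension of 1 (torch.Size([1, 4]) rather than torch.Size([4])) because the DataLoader treats each list of indices from the sampler as a single sample under its default batch_size=1; pass the sampler as batch_sampler=sampler instead and you get the plain shapes.

Second, if you would rather not manage the reset by hand, you can implement __iter__ as a generator instead of returning self. The DataLoader calls __iter__ once per epoch, so the generator's local state is fresh every time. Here is a sketch of that variant (same batches as above, with the end index additionally clamped to the dataset length for safety):

    class VariableBatchSampler(Sampler):
        def __init__(self, dataset_len: int, batch_sizes: list):
            self.dataset_len = dataset_len
            self.batch_sizes = batch_sizes

        def __iter__(self):
            # A new generator with fresh local state is created each epoch,
            # so there is nothing to reset.
            start_idx = 0
            for batch_size in self.batch_sizes:
                if start_idx >= self.dataset_len:
                    return
                end_idx = min(start_idx + batch_size, self.dataset_len)
                yield list(range(start_idx, end_idx))
                start_idx = end_idx
            if start_idx < self.dataset_len:
                # Mirror the IndexError fallback above: one final batch
                # with whatever remains of the dataset.
                yield list(range(start_idx, self.dataset_len))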