I'm trying to split one of the PyTorch torchvision datasets (MNIST) into a training set and a validation set as follows:
import numpy as np
import torch
from torch.utils.data import sampler
from torchvision import datasets, transforms

def get_train_valid_splits(data_dir,
                           batch_size,
                           random_seed=1,
                           valid_size=0.2,
                           shuffle=True,
                           num_workers=4,
                           pin_memory=False):
    normalize = transforms.Normalize((0.1307,), (0.3081,))  # MNIST
    # define transforms
    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    train_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize
    ])
    # load the dataset
    train_dataset = datasets.MNIST(root=data_dir, train=True,
                                   download=True, transform=train_transform)
    valid_dataset = datasets.MNIST(root=data_dir, train=True,
                                   download=True, transform=valid_transform)
    dataset_size = len(train_dataset)
    indices = list(range(dataset_size))
    split = int(np.floor(valid_size * dataset_size))
    if shuffle:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_idx, valid_idx = indices[split:], indices[:split]
    train_sampler = sampler.SubsetRandomSampler(train_idx)
    valid_sampler = sampler.SubsetRandomSampler(valid_idx)
    print(len(train_sampler))
    print(len(valid_sampler))
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=batch_size, sampler=train_sampler,
                                               num_workers=num_workers, pin_memory=pin_memory)
    valid_loader = torch.utils.data.DataLoader(valid_dataset,
                                               batch_size=batch_size, sampler=valid_sampler,
                                               num_workers=num_workers, pin_memory=pin_memory)
    print(len(train_loader.dataset))
    print(len(valid_loader.dataset))
    return (train_loader, valid_loader)
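For reference, I call the function roughly like this (the data directory and batch size here are just example values):

train_loader, valid_loader = get_train_valid_splits(data_dir='./data', batch_size=64)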
After calling the function I notice that the lengths of the samplers (i.e. the index splits) look right, 48000 and 12000:
print(len(train_sampler))
print(len(valid_sampler))
But when I look at the length of the dataset associated with train_loader and valid_loader:
print(len(train_loader.dataset))
print(len(valid_loader.dataset))
I get the same length for both: 60000! Any idea what is going on here? Why does each loader report the full length, even though I clearly split the dataset by indices?
It's because the DataLoader doesn't modify the dataset you pass it; it only "applies" things like the batch size and sampler when you iterate over it. Your issue is that len(loader.dataset) gives you the length of the underlying dataset, untouched by the sampler, whereas what you really wanted is len(loader), the number of batches the loader will yield, which does account for the sampler and the batch size.
import torch
import numpy as np

dataset = np.random.rand(100, 200)  # 100 samples, 200 features each
sampler = torch.utils.data.SubsetRandomSampler(list(range(70)))  # use only the first 70 indices
loader = torch.utils.data.DataLoader(dataset, sampler=sampler)

print(len(loader))          # number of batches (batch_size defaults to 1)
>>> 70
print(len(loader.dataset))  # length of the underlying dataset, sampler ignored
>>> 100
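If what you want is the number of individual samples rather than the number of batches, you can count them by iterating the loader; a quick sanity check on the same toy dataset:

n_samples = sum(len(batch) for batch in loader)  # total samples actually yielded
print(n_samples)
>>> 70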
Note: the result of len(loader) is affected by the batch size, since it counts batches, not samples:
# with batch size
loader = torch.utils.data.DataLoader(dataset, sampler=sampler, batch_size=2)
print(len(loader))          # 70 sampled items / batch size 2 = 35 batches
>>> 35
print(len(loader.dataset))  # still the full underlying dataset
>>> 100
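As an aside, if you would rather have len(loader.dataset) reflect the split itself, one option is to wrap the index lists in torch.utils.data.Subset instead of using samplers. A minimal sketch, reusing the train_idx / valid_idx (and the other arguments) from your function:

from torch.utils.data import Subset

train_subset = Subset(train_dataset, train_idx)  # len(train_subset) == 48000
valid_subset = Subset(valid_dataset, valid_idx)  # len(valid_subset) == 12000

train_loader = torch.utils.data.DataLoader(train_subset, batch_size=batch_size,
                                           shuffle=True, num_workers=num_workers,
                                           pin_memory=pin_memory)
valid_loader = torch.utils.data.DataLoader(valid_subset, batch_size=batch_size,
                                           shuffle=False, num_workers=num_workers,
                                           pin_memory=pin_memory)

print(len(train_loader.dataset))  # 48000
print(len(valid_loader.dataset))  # 12000

Because each Subset only exposes the chosen indices, the lengths of the loaders' datasets now match the split, and SubsetRandomSampler is no longer needed (shuffle=True on the training loader handles shuffling).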