
Save to disk training dataset and validation dataset separately in PyTorch


I want to save train dataset, test dataset, and validation dataset in 3 separate folders.

Doing this for the training and test sets is easy:

# Get training and test data
all_training_data = getattr(datasets, config['Pytorch_Dataset']['dataset'])(
    root=path_to_data + "/data_all_train_" + config['Pytorch_Dataset']['dataset'],
    train=True,
    download=True,  # skipped if the data is already present on disk
    transform=ToTensor()
)
test_data = getattr(datasets, config['Pytorch_Dataset']['dataset'])(
    root=path_to_data + "/data_test_" + config['Pytorch_Dataset']['dataset'],
    train=False,
    download=True,  # skipped if the data is already present on disk
    transform=ToTensor()
)

This code uses torchvision.datasets to download and save to disk the dataset specified in config['Pytorch_Dataset']['dataset'] (e.g. MNIST). However, there is no way to load a validation set the same way: the constructor accepts no validation=True option.

I could split the train dataset into train and validation sets with torch.utils.data.random_split, but that raises its own problems, chief among them keeping the split reproducible across runs.


Solution

  • You don't have to save the split results to separate folders to maintain reproducibility, which I assume is what you really care about.

    You could instead fix the seed before calling random_split:

    torch.manual_seed(42)
    # Fractional lengths like (0.7, 0.3) require a recent PyTorch (>= 1.13);
    # on older versions, pass absolute lengths instead.
    data_train, data_val = torch.utils.data.random_split(data_all_train, (0.7, 0.3))
    

    Then you only have to maintain the initial folders, while the train and validation splits stay consistent across trials.
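    To illustrate, here is a minimal sketch using a toy TensorDataset as a stand-in for data_all_train (the toy data is an assumption; any Dataset works the same way). Re-seeding before the split reproduces the exact same partition:

    ```python
    import torch
    from torch.utils.data import TensorDataset, random_split

    # Hypothetical toy dataset standing in for data_all_train
    data_all_train = TensorDataset(torch.arange(100).float())

    torch.manual_seed(42)
    train_a, val_a = random_split(data_all_train, (0.7, 0.3))

    torch.manual_seed(42)
    train_b, val_b = random_split(data_all_train, (0.7, 0.3))

    # Same seed -> same permutation -> identical splits across trials
    assert train_a.indices == train_b.indices
    assert val_a.indices == val_b.indices
    print(len(train_a), len(val_a))  # 70 30
    ```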


    The caveat is that you are fixing the global seed, so you also lose the randomness you may want elsewhere (e.g. dataloader shuffling), which would then be identical in every trial.

    To avoid that, you can narrow the scope of the seeding by setting it only on the generator you pass to the split call:

    split_gen = torch.Generator()
    split_gen.manual_seed(42)

    data_train, data_val = torch.utils.data.random_split(
        data_all_train,
        (0.7, 0.3),
        generator=split_gen
    )
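
    As a quick sanity check, the following sketch (again with a hypothetical toy dataset in place of data_all_train) confirms that the dedicated generator keeps the split stable across trials while leaving the global RNG state untouched:

    ```python
    import torch
    from torch.utils.data import TensorDataset, random_split

    # Hypothetical toy dataset standing in for data_all_train
    data_all_train = TensorDataset(torch.arange(100).float())

    def make_split(seed=42):
        # A dedicated generator makes the split deterministic without
        # touching the global RNG used for shuffling, dropout, etc.
        split_gen = torch.Generator()
        split_gen.manual_seed(seed)
        return random_split(data_all_train, (0.7, 0.3), generator=split_gen)

    state_before = torch.get_rng_state()
    train_a, val_a = make_split()
    train_b, val_b = make_split()

    # The split is stable across trials...
    assert val_a.indices == val_b.indices
    # ...while the global RNG state is left untouched
    assert torch.equal(state_before, torch.get_rng_state())
    ```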