I want to save train dataset, test dataset, and validation dataset in 3 separate folders.
Doing this for training and testing is easy:
from torchvision import datasets
from torchvision.transforms import ToTensor

# Get training and testing data
all_training_data = getattr(datasets, config['Pytorch_Dataset']['dataset'])(
    root=path_to_data + "/data_all_train_" + config['Pytorch_Dataset']['dataset'],
    train=True,
    download=True,  # if the data is already present, it will not be downloaded again
    transform=ToTensor()
)
test_data = getattr(datasets, config['Pytorch_Dataset']['dataset'])(
    root=path_to_data + "/data_test_" + config['Pytorch_Dataset']['dataset'],
    train=False,
    download=True,  # if the data is already present, it will not be downloaded again
    transform=ToTensor()
)
This code uses torchvision.datasets to download (if needed) and load from disk the dataset specified in config['Pytorch_Dataset']['dataset'] (e.g. MNIST). However, there is no way to load a validation set this way: there is no validation=True option.
I could split the train dataset into train and validation with torch.utils.data.random_split, but there are two main problems with this approach:

1. Instead of a single data_all_train folder, I want to save only 2 folders, one with the true training part (data_train) and one with the validation part (data_validation).
2. The code should check whether data_train and data_validation are already present, and in this case it should not download data_all_train again, even if it is not present.

You don't have to save the split results in separate folders to maintain reproducibility, which is what I am assuming you really care about.
You could instead fix the seed before calling random_split, like this:
torch.manual_seed(42)  # fix the global RNG so the split below is deterministic
data_train, data_val = torch.utils.data.random_split(data_all_train, (0.7, 0.3))
Then you keep just the initial folders while still ensuring that the train and validation splits are consistent across trials.
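If you want to convince yourself of this, a quick sanity check (a hypothetical snippet, assuming data_all_train is already loaded as above) is to perform the split twice with the same seed and compare the resulting indices:
torch.manual_seed(42)
split_a, _ = torch.utils.data.random_split(data_all_train, (0.7, 0.3))
torch.manual_seed(42)
split_b, _ = torch.utils.data.random_split(data_all_train, (0.7, 0.3))
# random_split returns Subset objects; identical index lists mean identical splits
assert split_a.indices == split_b.indices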
But the caveat to this global-seed approach is that you also lose the randomness you might want elsewhere, such as DataLoader shuffling, which will end up identical across trials.
To avoid that, you can narrow the scope of the seeding by setting it only on the generator you pass to the split call:
split_gen = torch.Generator()  # dedicated generator, so the global RNG stays untouched
split_gen.manual_seed(42)
data_train, data_val = torch.utils.data.random_split(
    data_all_train,
    (0.7, 0.3),
    generator=split_gen
)
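For example (a minimal sketch: the batch size of 64 is just an assumption, and data_train / data_val come from the split above), the loaders below still shuffle differently on every run because they draw from the untouched global RNG, while the split membership stays pinned by split_gen:
from torch.utils.data import DataLoader

# Shuffling here uses the global RNG, so batch order still varies per run,
# while the train/val membership is fixed by split_gen above.
train_loader = DataLoader(data_train, batch_size=64, shuffle=True)
val_loader = DataLoader(data_val, batch_size=64, shuffle=False)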