I'm reading through a PyTorch tutorial for transfer learning and I'm having a hard time figuring out exactly what the following block is doing, with regards to generating the dataset:
import os
import torch
from torchvision import datasets, transforms

# Data augmentation and normalization for training
# Just normalization for validation
data_transforms = {
    'train': transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
    'val': transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    ]),
}
data_dir = 'data/hymenoptera_data'
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
class_names = image_datasets['train'].classes
In particular, I'm confused by expressions of the form (excuse any misuse of terminology, I'm still learning Python) {x: len(image_datasets[x]) for x in ['train', 'val']}.
It was explained to me before but I've since forgotten and I'm not sure how to phrase the question for a general internet search, so I'm asking here. I know that the code is defining a loop but the syntax is confusing me. Any clarification would be greatly appreciated.
I tried googling some stuff.
data_dir = 'data/hymenoptera_data'
Defines the directory where the dataset is located.
image_datasets = {x: datasets.ImageFolder(os.path.join(data_dir, x),
                                          data_transforms[x])
                  for x in ['train', 'val']}
This is a dictionary comprehension (not a list comprehension): it loops over x in ['train', 'val'] and builds a dictionary whose keys are 'train' and 'val' (the training and validation splits) and whose values are the ImageFolder datasets for those splits.
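To make the syntax concrete: a dictionary comprehension {key_expr: value_expr for x in iterable} builds the same dictionary as an explicit loop that does d[key_expr] = value_expr on each pass. The sketch below uses a hypothetical stand-in, load_split, in place of datasets.ImageFolder and placeholder strings in place of the real transforms, so it runs without torchvision installed:

```python
# Hypothetical stand-in for datasets.ImageFolder(path, transform),
# so this example runs without torchvision installed.
def load_split(path, transform):
    return f"dataset({path}, {transform})"

data_transforms = {"train": "train_tf", "val": "val_tf"}  # placeholder transforms
splits = ["train", "val"]

# Dictionary comprehension: one key/value pair per item in splits
image_datasets = {x: load_split("data/" + x, data_transforms[x]) for x in splits}

# Equivalent explicit loop
image_datasets_loop = {}
for x in splits:
    image_datasets_loop[x] = load_split("data/" + x, data_transforms[x])

print(image_datasets)
# {'train': 'dataset(data/train, train_tf)', 'val': 'dataset(data/val, val_tf)'}
```

Both versions produce the same dictionary; the comprehension is just the compact form of the loop.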
ImageFolder(os.path.join(data_dir, x),data_transforms[x])
ImageFolder is a torchvision dataset class that loads images from a directory in which each subfolder represents one class. It assigns class labels automatically and applies the relevant transformations, data_transforms[x], which are defined above for the train and val sets.
os.path.join(data_dir, x)
Constructs the full path to each split's directory (data/hymenoptera_data/train and data/hymenoptera_data/val), which ImageFolder then uses as its root directory.
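A quick check of the paths the comprehension passes to ImageFolder. This assumes a POSIX system (Linux or macOS); on Windows, os.path.join would insert a backslash before the split name instead:

```python
import os

data_dir = "data/hymenoptera_data"

# One path per split, built with the same dictionary-comprehension pattern
split_dirs = {x: os.path.join(data_dir, x) for x in ["train", "val"]}
print(split_dirs)
# On Linux/macOS: {'train': 'data/hymenoptera_data/train', 'val': 'data/hymenoptera_data/val'}
```

Each of those directories is expected to contain one subfolder per class (for this dataset, ants and bees).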
dataloaders = {x: torch.utils.data.DataLoader(image_datasets[x], batch_size=4,
                                              shuffle=True, num_workers=4)
               for x in ['train', 'val']}
This is another dictionary comprehension where it loops over x in ['train', 'val'], meaning when x = 'train', it creates a DataLoader for the training dataset, and when x = 'val', it creates a DataLoader for the validation dataset. It then assigns the created DataLoader to the corresponding key ('train' or 'val').
torch.utils.data.DataLoader(image_datasets[x], batch_size=4, shuffle=True, num_workers=4)
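Wraps the dataset in a DataLoader: batch_size=4 sets how many samples are returned per batch, shuffle=True reshuffles the data at every epoch (important for training, usually unnecessary for validation), and num_workers=4 sets how many worker subprocesses load data in parallel (speeds up data loading).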
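As a rough mental model (a simplified sketch, not how torch.utils.data.DataLoader is actually implemented, and ignoring num_workers and tensor collation entirely), one pass over a loader with shuffle=True and batch_size=4 behaves like this: shuffle the indices, then hand out consecutive groups of four. The helper iterate_batches and the list of fake samples are hypothetical:

```python
import random

def iterate_batches(dataset, batch_size, shuffle, seed=None):
    """Simplified model of one epoch over a DataLoader (no workers, no collation)."""
    indices = list(range(len(dataset)))
    if shuffle:
        # The real DataLoader draws a fresh random order every epoch
        random.Random(seed).shuffle(indices)
    for start in range(0, len(indices), batch_size):
        # The last batch may be smaller (like DataLoader's default drop_last=False)
        yield [dataset[i] for i in indices[start:start + batch_size]]

samples = [f"img{i}" for i in range(10)]  # 10 fake samples
batches = list(iterate_batches(samples, batch_size=4, shuffle=True, seed=0))
print([len(b) for b in batches])  # [4, 4, 2]
```

Every sample appears exactly once per epoch; only the order and grouping change.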
dataset_sizes = {x: len(image_datasets[x]) for x in ['train', 'val']}
This is another dictionary comprehension over x in ['train', 'val']: when x = 'train' it stores the number of images in the training dataset under the key 'train', and when x = 'val' it stores the number of images in the validation dataset under the key 'val'.
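Since image_datasets is already a dictionary keyed by split, the same sizes can also be computed by iterating over its items, which avoids repeating the ['train', 'val'] list. A sketch with hypothetical plain lists standing in for the two datasets (len() works on any object that defines __len__, which ImageFolder does):

```python
# Hypothetical stand-ins: plain lists instead of ImageFolder datasets
image_datasets = {"train": ["a.jpg", "b.jpg", "c.jpg"], "val": ["d.jpg", "e.jpg"]}

# Form used in the tutorial: loop over the split names
dataset_sizes = {x: len(image_datasets[x]) for x in ["train", "val"]}

# Alternative: unpack key/value pairs directly from the dict
dataset_sizes_alt = {name: len(ds) for name, ds in image_datasets.items()}

print(dataset_sizes)  # {'train': 3, 'val': 2}
```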
class_names = image_datasets['train'].classes
Extracts the list of class names from the training dataset. ImageFolder takes these from the subfolder names, sorted alphabetically, so for this dataset they are ['ants', 'bees'].
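To see where those names come from, the sketch below imitates ImageFolder's class discovery on a temporary directory: list the subfolders, sort them, and number them in that order. find_class_names is a hypothetical helper written for illustration, not the torchvision function, and class_to_idx here only mirrors the dataset attribute of the same name:

```python
import os
import tempfile

def find_class_names(root):
    """Hypothetical helper: sorted names of the subdirectories of root."""
    return sorted(entry.name for entry in os.scandir(root) if entry.is_dir())

with tempfile.TemporaryDirectory() as root:
    # Build a layout like data/hymenoptera_data/train/{bees,ants}
    for name in ["bees", "ants"]:
        os.makedirs(os.path.join(root, name))
    class_names = find_class_names(root)

class_to_idx = {name: i for i, name in enumerate(class_names)}
print(class_names)   # ['ants', 'bees']
print(class_to_idx)  # {'ants': 0, 'bees': 1}
```

Because the label indices follow this sorted order, class_names can map a predicted index back to a readable label, e.g. class_names[1] is 'bees'.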
Hope this helps