
Plot the transformed (augmented) images in PyTorch


I want to use one of the image augmentation techniques (for example rotation or horizontal flip) and apply it to some images of the CIFAR-10 dataset and plot them in PyTorch.

I know that we can use the following code to augment images:

from torchvision import models, datasets, transforms
from torchvision.datasets import CIFAR10

# mean and std were not defined in the original snippet; 0.5 per channel is an
# assumption that matches the [-1, 1] range mentioned below.
mean = (0.5, 0.5, 0.5)
std = (0.5, 0.5, 0.5)

data_transforms = transforms.Compose([
    # add augmentations
    transforms.RandomHorizontalFlip(p=0.5),
    # torchvision datasets return PIL images; ToTensor scales them to [0, 1],
    # and Normalize then maps them to the range [-1, 1].
    transforms.ToTensor(),
    transforms.Normalize(mean, std),
])

I then pass these transforms when loading the CIFAR-10 dataset:

train_set = CIFAR10(
    root='./data/',
    train=True,
    download=True,
    transform=data_transforms)

As far as I know, when this code is used, all CIFAR10 datasets are transformed.

Question

How can I apply a transform or augmentation to only some images in the dataset and plot them, for example 10 images alongside their augmented versions?


Solution

  • "when this code is used, all CIFAR10 datasets are transformed"

    Actually, the transform pipeline is only called when images are fetched from the dataset via the __getitem__ function, either directly by the user or through a data loader. So at this point train_set doesn't contain augmented images; they are transformed on the fly each time they are accessed.
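To make the on-the-fly behaviour concrete, here is a minimal pure-Python sketch. `ToyDataset` and `random_flip` are hypothetical stand-ins written for illustration, not torchvision code; they only mimic the pattern of running the transform inside `__getitem__`.

```python
import random


class ToyDataset:
    """Hypothetical stand-in that applies its transform at fetch time."""

    def __init__(self, samples, transform=None):
        self.samples = samples        # raw data, stored untouched
        self.transform = transform    # callable applied on each access

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        sample = self.samples[index]
        if self.transform is not None:
            sample = self.transform(sample)
        return sample


def random_flip(row, p=0.5):
    """Reverse a list with probability p (a 1-D analogue of a horizontal flip)."""
    return row[::-1] if random.random() < p else row


random.seed(0)
dataset = ToyDataset([[1, 2, 3], [4, 5, 6]], transform=random_flip)

# Fetch the same index several times: each access re-runs the random transform.
fetched = [dataset[0] for _ in range(20)]

# The stored data is never modified.
print(dataset.samples[0])  # prints [1, 2, 3]
```

Repeated reads of the same index can therefore return different results, which is also why indexing `train_set` twice may give a flipped image one time and an unflipped one the next.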


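    To compare originals with augmented versions, construct a second dataset without augmentations. It still needs transforms.ToTensor(), because torch.stack below expects tensors rather than PIL images.

    >>> non_augmented = CIFAR10(
    ...     root='./data/',
    ...     train=True,
    ...     download=True,
    ...     transform=transforms.ToTensor())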
    
    >>> train_set = CIFAR10(
    ...     root='./data/',
    ...     train=True,
    ...     download=True,
    ...     transform=data_transforms)
    

    Stack some images together:

    >>> import torch
    >>> imgs = torch.stack([non_augmented[i][0] for i in range(10)]
    ...                    + [train_set[i][0] for i in range(10)])
    
    >>> imgs.shape
    torch.Size([20, 3, 32, 32])
    

    Then torchvision.utils.make_grid can be useful to create the desired layout:

    >>> import torchvision
    >>> grid = torchvision.utils.make_grid(imgs, nrow=10)
    

    There you have it!

    >>> transforms.ToPILImage()(grid)
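One caveat: images taken from train_set have passed through transforms.Normalize, so their values are no longer in [0, 1] and their colors may look wrong when converted to a PIL image. The sketch below undoes the normalization before plotting. The helper name `denormalize` is hypothetical, and the 0.5 mean and std values are an assumption; substitute whatever was passed to Normalize.

```python
import torch

# Assumed normalization statistics; use the values given to transforms.Normalize.
MEAN = (0.5, 0.5, 0.5)
STD = (0.5, 0.5, 0.5)


def denormalize(imgs, mean=MEAN, std=STD):
    """Invert Normalize for a batch shaped (N, C, H, W): x * std + mean per channel."""
    mean_t = torch.tensor(mean, dtype=imgs.dtype).view(1, -1, 1, 1)
    std_t = torch.tensor(std, dtype=imgs.dtype).view(1, -1, 1, 1)
    # Clamp to guard against tiny floating-point overshoots outside [0, 1].
    return (imgs * std_t + mean_t).clamp(0.0, 1.0)


# Stand-in batch of random "images" in [0, 1], shaped like ten CIFAR-10 samples.
original = torch.rand(10, 3, 32, 32)
normalized = (original - 0.5) / 0.5   # what Normalize(MEAN, STD) computes
restored = denormalize(normalized)
```

In the pipeline above, this would be applied to the stacked images from train_set before they are combined with the non-augmented ones and passed to make_grid.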
    
