python · image-processing · pytorch · computer-vision · torchvision

PyTorch: How to apply the same random transformation to multiple images?


I am writing a simple transformation for a dataset which contains many pairs of images. As data augmentation, I want to apply some random transformation to each pair, but the two images in a pair must be transformed in the same way. For example, given a pair of images A and B, if A is flipped horizontally, B must be flipped horizontally as well. The next pair C and D may be transformed differently from A and B, but C and D must again share the same transformation. I am trying it in the way below:

import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image

img_a = Image.open("sample_a.jpg") # note that two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")

transform = transforms.RandomChoice(
    [transforms.RandomHorizontalFlip(), 
     transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))

random.seed(1)
display(transform(img_c))
display(transform(img_d))

Yet, the above code does not choose the same transformation for both images in a pair; as I tested, the chosen transform depends on how many times transform has been called.

Is there any way to force transforms.RandomChoice to use the same transform when specified?


Solution

  • Usually, a workaround is to apply the transform to the first image, retrieve the parameters of that transform, and then apply a deterministic transform with those parameters to the remaining images. However, RandomChoice does not provide an API to get the parameters of the applied transform, since it involves a variable number of transforms. In such cases, I usually override the original implementation.

    Looking at the torchvision implementation, it's as simple as:

    class RandomChoice(RandomTransforms):
        def __call__(self, img):
            t = random.choice(self.transforms)
            return t(img)
    

    Here are two possible solutions.

    1. You can either sample from the transform list on __init__ instead of on __call__:

      import random
      import torch
      import torchvision.transforms as T
      
      class RandomChoice(torch.nn.Module):
          def __init__(self, transforms):
              super().__init__()
              self.transforms = transforms
              # sample the transform once, on __init__
              self.t = random.choice(self.transforms)
      
          def __call__(self, img):
              return self.t(img)
      

      So you can do:

      transform = RandomChoice([
           T.RandomHorizontalFlip(), 
           T.RandomVerticalFlip()
      ])
      display(transform(img_a)) # both img_a and img_b will
      display(transform(img_b)) # have the same transform
      
      transform = RandomChoice([
          T.RandomHorizontalFlip(), 
          T.RandomVerticalFlip()
      ])
      display(transform(img_c)) # both img_c and img_d will
      display(transform(img_d)) # have the same transform
      

    2. Or better yet, transform the images in a batch:

      import random
      import torch
      import torchvision.transforms as T
      
      class RandomChoice(torch.nn.Module):
          def __init__(self, transforms):
              super().__init__()
              self.transforms = transforms
      
          def __call__(self, imgs):
              t = random.choice(self.transforms)
              return [t(img) for img in imgs]
      

      Which allows you to do:

      transform = RandomChoice([
           T.RandomHorizontalFlip(), 
           T.RandomVerticalFlip()
      ])
      
      img_at, img_bt = transform([img_a, img_b])
      display(img_at) # both img_a and img_b will
      display(img_bt) # have the same transform
      
      img_ct, img_dt = transform([img_c, img_d])
      display(img_ct) # both img_c and img_d will
      display(img_dt) # have the same transform