I am writing a simple transformation for a dataset that contains many pairs of images. As data augmentation, I want to apply a random transformation to each pair, but both images in a pair should be transformed in the same way.
For example, given a pair of images A and B, if A is flipped horizontally, B must be flipped horizontally as well. The next pair, C and D, should be transformed differently from A and B, but C and D must again be transformed in the same way as each other. I am trying to do this as follows:
import random
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
img_a = Image.open("sample_a.jpg")  # note that the two images have the same size
img_b = Image.open("sample_b.png")
img_c, img_d = Image.open("sample_c.jpg"), Image.open("sample_d.png")
transform = transforms.RandomChoice(
[transforms.RandomHorizontalFlip(),
transforms.RandomVerticalFlip()]
)
random.seed(0)
display(transform(img_a))
display(transform(img_b))
random.seed(1)
display(transform(img_c))
display(transform(img_d))
Yet the above code does not choose the same transformation for both images, and as far as I have tested, the result depends on the number of times transform is called.
Is there any way to force transforms.RandomChoice to use the same transform when specified?
Usually a workaround is to apply the transform to the first image, retrieve the parameters of that transform, then apply a deterministic transform with those parameters to the remaining images. However, RandomChoice does not provide an API to get the parameters of the applied transform, since it involves a variable number of transforms.
When no such API is available, I usually override the original implementation.
Looking at the torchvision implementation, it's as simple as:
class RandomChoice(RandomTransforms):
    def __call__(self, img):
        t = random.choice(self.transforms)
        return t(img)
Here are two possible solutions.
You can either sample from the transform list in __init__ instead of in __call__:
import random
import torch
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        # Pick one transform once, when the object is created
        self.t = random.choice(transforms)

    def __call__(self, img):
        # Every call reuses the transform picked in __init__
        return self.t(img)
So you can do:
transform = RandomChoice([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip()
])
display(transform(img_a))  # both img_a and img_b go through
display(transform(img_b))  # the same chosen transform

# A new instance samples again for the next pair
transform = RandomChoice([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip()
])
display(transform(img_c))  # both img_c and img_d go through
display(transform(img_d))  # the same chosen transform
Or better yet, transform the images in a batch:
import random
import torch
import torchvision.transforms as T

class RandomChoice(torch.nn.Module):
    def __init__(self, transforms):
        super().__init__()
        self.transforms = transforms

    def __call__(self, imgs):
        # Pick one transform per call and apply it to every image in the batch
        t = random.choice(self.transforms)
        return [t(img) for img in imgs]
Which allows you to do:
transform = RandomChoice([
    T.RandomHorizontalFlip(),
    T.RandomVerticalFlip()
])
img_at, img_bt = transform([img_a, img_b])
display(img_at)  # both img_a and img_b go through
display(img_bt)  # the same chosen transform
img_ct, img_dt = transform([img_c, img_d])
display(img_ct)  # both img_c and img_d go through
display(img_dt)  # the same chosen transform