Tags: python · pytorch · albumentations

Use custom transformer for albumentations


I want to use the following custom albumentation transformer

import albumentations as A
from albumentations.pytorch import ToTensorV2

class RandomTranslateWithReflect:
    """Translate image randomly
    Translate vertically and horizontally by n pixels where
    n is integer drawn uniformly independently for each axis
    from [-max_translation, max_translation].
    Fill the uncovered blank area with reflect padding.
    """

    def __init__(self, max_translation):
        # Maximum absolute shift, in pixels, for each axis.
        self.max_translation = max_translation

    def __call__(self, old_image):
        # NOTE(review): A.Lambda invokes its image callback with extra keyword
        # arguments (the traceback below shows `cols`); this signature accepts
        # only the image, which is what raises the TypeError.
        xtranslation, ytranslation = np.random.randint(-self.max_translation,
                                                       self.max_translation + 1,
                                                       size=2)
        # Apply the translation using the Albumentations library
        # NOTE(review): shift_limit is a sampling *range*, not an exact
        # (x, y) shift — presumably not the intended semantics; verify
        # against the ShiftScaleRotate documentation.
        transform = A.ShiftScaleRotate(shift_limit=(xtranslation / old_image.shape[1], ytranslation / old_image.shape[0]),
                                       scale_limit=0,
                                       rotate_limit=0,
                                       border_mode=cv2.BORDER_REFLECT,
                                       p=1)
        new_image = transform(image=old_image)["image"]
        return new_image

I use the following code to put it into a `Compose` pipeline:

train_transform = {
    'cifar10': A.Compose([
        # Wrapping the plain callable in A.Lambda is what triggers the
        # TypeError reported below: Lambda forwards extra kwargs (cols, rows)
        # that __call__ above does not accept.
        A.Lambda(image=RandomTranslateWithReflect(4)),
        A.HorizontalFlip(p=0.5),
        # `meanstd` is defined elsewhere — presumably a (mean, std) pair per
        # dataset name; verify against its definition.
        A.Normalize(*meanstd['cifar10']),
        ToTensorV2()
    ])
}

but I get this error:

mg1 = transform(image=img)["image"]
  File "/home/student/anaconda3/envs/few-shot/lib/python3.6/site-packages/albumentations/core/composition.py", line 205, in __call__
    data = t(**data)
  File "/home/student/anaconda3/envs/few-shot/lib/python3.6/site-packages/albumentations/core/transforms_interface.py", line 118, in __call__
    return self.apply_with_params(params, **kwargs)
  File "/home/student/anaconda3/envs/few-shot/lib/python3.6/site-packages/albumentations/core/transforms_interface.py", line 131, in apply_with_params
    res[key] = target_function(arg, **dict(params, **target_dependencies))
  File "/home/student/anaconda3/envs/few-shot/lib/python3.6/site-packages/albumentations/augmentations/transforms.py", line 1648, in apply
    return fn(img, **params)
TypeError: __call__() got an unexpected keyword argument 'cols'

Solution

  • I figured out how to make a custom transformation (by subclassing `ImageOnlyTransform`) and use it:

    class RandomTranslateWithReflect(ImageOnlyTransform):
        """Translate the image randomly with reflect padding.

        Shifts the image horizontally and vertically by integer offsets drawn
        uniformly and independently from [-max_translation, max_translation];
        the uncovered border is filled by reflection.

        Args:
            max_translation (int): maximum absolute shift per axis, in pixels.
            always_apply (bool): albumentations flag — apply regardless of ``p``.
            p (float): probability of applying the transform.
        """

        def __init__(self, max_translation, always_apply=False, p=1):
            super().__init__(always_apply, p)
            self.max_translation = max_translation

        def apply(self, old_image, **params):
            # Draw independent integer shifts for each axis.
            xtranslation, ytranslation = np.random.randint(-self.max_translation,
                                                           self.max_translation + 1,
                                                           size=2)
            t = self.max_translation
            if t == 0:
                # No room to shift — return the input unchanged.
                return old_image
            # Bug fix: ShiftScaleRotate's shift_limit is a *sampling range*
            # applied to both axes, not an exact (x, y) shift, so the original
            # never actually translated by the drawn offsets. Apply the exact
            # translation directly: reflect-pad by t on each spatial edge, then
            # crop a same-sized window offset by the drawn shifts.
            # 'symmetric' mirrors including the edge pixel, matching
            # cv2.BORDER_REFLECT; extra trailing axes (e.g. channels) are
            # left unpadded.
            pad_widths = ((t, t), (t, t)) + ((0, 0),) * (old_image.ndim - 2)
            padded = np.pad(old_image, pad_widths, mode='symmetric')
            y0 = t + ytranslation
            x0 = t + xtranslation
            return padded[y0:y0 + old_image.shape[0], x0:x0 + old_image.shape[1]]
    
    # Create the Albumentations Compose pipeline
    train_transform = A.Compose([
        RandomTranslateWithReflect(4),  # Use the custom class directly
        A.HorizontalFlip(p=0.5),
        # NOTE(review): these are the ImageNet mean/std values, whereas the
        # question's pipeline used meanstd['cifar10'] — confirm which
        # statistics are intended for this dataset.
        A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ToTensorV2()
    ])