
How do I rotate a PyTorch image tensor around its center in a way that supports autograd?


I'd like to randomly rotate an image tensor (B, C, H, W) around its center (a 2D rotation, I think?). I would like to avoid using NumPy and Kornia, so that I basically only need to import from the torch module. I'm also not using torchvision.transforms, because I need the rotation to be autograd compatible. Essentially, I'm trying to create an autograd-compatible version of torchvision.transforms.RandomRotation() for visualization techniques like DeepDream (so I need to avoid artifacts as much as possible).
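For reference, here is the geometry as a short sketch. The symbols are my own labels, not from the question. Rotating by angle θ about the image center (c_x, c_y) maps a location (x, y) to (x', y'). In coordinates normalized so that the center is the origin (the convention used by torch.nn.functional.affine_grid), the translation terms vanish. What remains is a 2×3 matrix with a zero last column, which is the shape of `theta` that affine_grid expects. affine_grid applies it to every output location to find where to sample the input. For non-square images the normalized x and y axes cover different numbers of pixels, so the off-diagonal terms would additionally need an H/W correction for a rigid rotation.

```latex
\begin{pmatrix} x' \\ y' \end{pmatrix}
= \begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
  \begin{pmatrix} x - c_x \\ y - c_y \end{pmatrix}
+ \begin{pmatrix} c_x \\ c_y \end{pmatrix}
\quad\xrightarrow{\;c_x = c_y = 0\;}\quad
\Theta = \begin{pmatrix} \cos\theta & -\sin\theta & 0 \\ \sin\theta & \cos\theta & 0 \end{pmatrix}
```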

import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around its center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor

# Get a random angle within a specified range 
r_degrees = 5
angle = random.randint(-r_degrees, r_degrees)

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')
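To keep everything inside torch, as the question asks, the random angle could also come from torch's own RNG. Then torch.manual_seed alone makes runs reproducible. The following is a minimal sketch. `random_angle_rad` is a hypothetical helper name, and it draws a continuous angle rather than whole degrees, which is an assumption about what is wanted:

```python
import math

import torch


def random_angle_rad(max_degrees: float) -> float:
    """Hypothetical helper: uniform angle in [-max_degrees, max_degrees), returned in radians."""
    # torch.rand(()) is a 0-d tensor in [0, 1); map it to [-1, 1) and scale.
    degrees = (torch.rand(()) * 2 - 1) * max_degrees
    return math.radians(degrees.item())


torch.manual_seed(0)
ang_rad = random_angle_rad(5)
```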

Some example outputs of what I'm trying to accomplish:

[Images: two examples of the desired rotated output]


Solution

  • The grid generator and the sampler are sub-modules of the Spatial Transformer (JADERBERG, Max, et al.). These sub-modules are not trainable; they let you apply a learnable, as well as a non-learnable, spatial transformation. Here I take these two sub-modules and use them to rotate an image by theta using PyTorch's functions torch.nn.functional.affine_grid and torch.nn.functional.grid_sample (these functions are implementations of the generator and the sampler, respectively):

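    import math
    import torch
    import torch.nn.functional as F
    import matplotlib.pyplot as plt
    
    def get_rot_mat(theta):
        # 2x3 affine matrix: a 2x2 rotation plus a zero translation column
        theta = torch.tensor(theta)
        return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                             [torch.sin(theta), torch.cos(theta), 0]])
    
    
    def rot_img(x, theta, dtype):
        # One copy of the rotation matrix per image in the batch: shape (B, 2, 3)
        rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0], 1, 1)
        # Grid generator: normalized input sampling coordinates for every output pixel
        grid = F.affine_grid(rot_mat, x.size(), align_corners=False).type(dtype)
        # Sampler: bilinear lookup of x at those coordinates (differentiable w.r.t. x)
        x = F.grid_sample(x, grid, align_corners=False)
        return x
    
    
    # Test:
    dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
    # im should be a 4D tensor of shape B x C x H x W with type dtype, range [0, 255].
    # Placeholder input; replace it with your own image tensor:
    im = (torch.rand(1, 3, 256, 256) * 255).type(dtype)
    plt.imshow(im.squeeze(0).permute(1, 2, 0).cpu() / 255)  # To plot it, im should be 1 x C x H x W
    plt.figure()
    # Rotation by math.pi/2 with autograd support:
    rotated_im = rot_img(im, math.pi / 2, dtype)  # Rotate image by 90 degrees.
    plt.imshow(rotated_im.squeeze(0).permute(1, 2, 0).cpu() / 255)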
    

    In the example above, assume our image, im, is a dancing cat in a skirt: [image]

    rotated_im will then be the dancing cat in a skirt rotated 90 degrees counterclockwise:

    [image]

    And this is what we get if we call rot_img with theta equal to π/4 (45 degrees): [image]

    And the best part is that it's differentiable with respect to the input, so it has autograd support. Hooray!
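A possible extension toward the question's RandomRotation use case is sketched below. This is not the answer's own code, and it assumes a PyTorch version whose affine_grid and grid_sample accept `align_corners`. The names `rotation_theta`, `rotate`, and `random_rotate` are hypothetical. It makes three changes.

- It builds the matrix with torch.stack instead of torch.tensor, so gradients can also reach the angle, not only the image.
- It rescales the off-diagonal terms by H/W and W/H so that non-square images are rotated rigidly rather than sheared.
- It defaults to `padding_mode="reflection"` in the random helper as one option for avoiding black corners.

```python
import math

import torch
import torch.nn.functional as F


def rotation_theta(angles: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Build (B, 2, 3) matrices for F.affine_grid from per-sample angles in radians.

    torch.stack keeps the autograd graph, so gradients reach `angles` too.
    """
    cos, sin = torch.cos(angles), torch.sin(angles)
    zeros = torch.zeros_like(angles)
    # Normalized x spans `width` pixels and normalized y spans `height` pixels,
    # so the off-diagonal terms are rescaled to keep the rotation rigid in pixels.
    row0 = torch.stack([cos, -sin * height / width, zeros], dim=-1)
    row1 = torch.stack([sin * width / height, cos, zeros], dim=-1)
    return torch.stack([row0, row1], dim=-2)


def rotate(x: torch.Tensor, angles: torch.Tensor, padding_mode: str = "zeros") -> torch.Tensor:
    """Rotate a (B, C, H, W) batch about its center by `angles` (shape (B,), radians)."""
    _, _, h, w = x.shape
    theta = rotation_theta(angles.to(device=x.device, dtype=x.dtype), h, w)
    # Grid generator: input sampling location for every output pixel.
    grid = F.affine_grid(theta, list(x.shape), align_corners=False)
    # Sampler: bilinear lookup, differentiable w.r.t. both x and the grid.
    return F.grid_sample(x, grid, mode="bilinear", padding_mode=padding_mode,
                         align_corners=False)


def random_rotate(x: torch.Tensor, max_degrees: float,
                  padding_mode: str = "reflection") -> torch.Tensor:
    """Hypothetical RandomRotation-style helper: one uniform angle per image."""
    degrees = (torch.rand(x.shape[0], device=x.device) * 2 - 1) * max_degrees
    return rotate(x, degrees * (math.pi / 180), padding_mode=padding_mode)
```

Reflection padding fills the corners with mirrored content instead of zeros, which may or may not look better for DeepDream-style visualizations; "zeros" and "border" are the other modes grid_sample accepts, so it is worth trying each.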