I'd like to randomly rotate an image tensor (B, C, H, W) around its center (a 2D rotation, I think?). I would like to avoid using NumPy and Kornia, so that I basically only need to import from the torch module. I'm also not using torchvision.transforms, because I need it to be autograd compatible. Essentially, I'm trying to create an autograd-compatible version of torchvision.transforms.RandomRotation() for visualization techniques like DeepDream (so I need to avoid artifacts as much as possible).
import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image
# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
# Save image
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)
# Somehow rotate tensor around its center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor
# Get a random angle within a specified range
r_degrees = 5
# random.randint is inclusive on both ends
angle = random.randint(-r_degrees, r_degrees)
# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180
# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)
# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)
# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')
Some example outputs of what I'm trying to accomplish:
So the grid generator and the sampler are sub-modules of the Spatial Transformer (JADERBERG, Max, et al.). These sub-modules are not trainable; they let you apply a learnable, as well as a non-learnable, spatial transformation.
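As a point of reference, the mapping that the grid generator applies can be written out explicitly. This is a sketch in the notation of the Spatial Transformer paper, not something taken from the code in this answer: $(x^t, y^t)$ are target (output) coordinates, $(x^s, y^s)$ are the source coordinates that get sampled, and both are normalized to $[-1, 1]$ so that the origin sits at the image center.

```latex
\begin{pmatrix} x^{s} \\ y^{s} \end{pmatrix}
= A_{\theta}
\begin{pmatrix} x^{t} \\ y^{t} \\ 1 \end{pmatrix},
\qquad
A_{\theta} =
\begin{pmatrix}
\cos\theta & -\sin\theta & 0 \\
\sin\theta & \cos\theta & 0
\end{pmatrix}
```

With a zero translation column, the rotation is about the origin of the normalized coordinates, which is the center of the image.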
Here I take these two submodules and use them to rotate an image by theta using PyTorch's functions torch.nn.functional.affine_grid and torch.nn.functional.grid_sample (these functions are implementations of the generator and the sampler, respectively):
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
def get_rot_mat(theta):
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                         [torch.sin(theta), torch.cos(theta), 0]])
def rot_img(x, theta, dtype):
    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
    grid = F.affine_grid(rot_mat, x.size()).type(dtype)
    x = F.grid_sample(x, grid)
    return x
#Test:
dtype = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255]:
plt.imshow(im.squeeze(0).permute(1,2,0)/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0)/255)
In the example above, assume we take our image, im, to be a dancing cat in a skirt:
rotated_im will be a 90-degree CCW rotated dancing cat in a skirt:
And this is what we get if we call rot_img with theta equal to np.pi/4:
And the best part is that it's differentiable w.r.t. the input and has autograd support! Hooray!
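To tie this back to the question, here is a minimal torch-only sketch of a random rotation with a quick gradient check. The helper name random_rotate, the uniform sampling of one angle per batch element, and the explicit align_corners=False are my assumptions for illustration; they are not part of the code above.

```python
import math

import torch
import torch.nn.functional as F


def random_rotate(x, max_degrees):
    """Rotate each image of a (B, C, H, W) batch by its own random angle.

    Hypothetical helper: the name and the uniform sampling are assumptions.
    """
    b = x.shape[0]
    # One angle per sample, uniform in [-max_degrees, max_degrees], in radians.
    u = torch.rand(b, device=x.device, dtype=x.dtype) * 2 - 1
    theta = u * math.radians(max_degrees)
    cos, sin = torch.cos(theta), torch.sin(theta)
    zeros = torch.zeros_like(theta)
    # Stack into a batch of 2x3 affine matrices, shape (B, 2, 3).
    rot_mat = torch.stack(
        [torch.stack([cos, -sin, zeros], dim=1),
         torch.stack([sin, cos, zeros], dim=1)],
        dim=1,
    )
    # Use the same align_corners value in both calls so they stay consistent.
    grid = F.affine_grid(rot_mat, x.size(), align_corners=False)
    return F.grid_sample(x, grid, align_corners=False)


# Gradient check: the input is a leaf tensor that requires grad.
x = torch.rand(2, 3, 64, 64, requires_grad=True)
out = random_rotate(x, max_degrees=5)
out.sum().backward()
print(out.shape, x.grad.shape)
```

After backward(), x.grad has the same shape as x, which is what an optimization loop such as DeepDream needs. Note that regions rotated in from outside the image are filled with zeros by default; grid_sample's padding_mode argument can change that if the black corners are a problem.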