 

How do I rotate a PyTorch image tensor around its center in a way that supports autograd?

I'd like to randomly rotate an image tensor (B, C, H, W) around its center (a 2D rotation, I think?). I would like to avoid using NumPy and Kornia, so that I basically only need to import from the torch module. I'm also not using torchvision.transforms, because I need it to be autograd compatible. Essentially, I'm trying to create an autograd-compatible version of torchvision.transforms.RandomRotation() for visualization techniques like DeepDream (so I need to avoid artifacts as much as possible).

import torch
import math
import random
import torchvision.transforms as transforms
from PIL import Image


# Load image
def preprocess_simple(image_name, image_size):
    Loader = transforms.Compose([transforms.Resize(image_size), transforms.ToTensor()])
    image = Image.open(image_name).convert('RGB')
    return Loader(image).unsqueeze(0)
    
# Save image   
def deprocess_simple(output_tensor, output_name):
    output_tensor.clamp_(0, 1)
    Image2PIL = transforms.ToPILImage()
    image = Image2PIL(output_tensor.squeeze(0))
    image.save(output_name)


# Somehow rotate tensor around its center
def rotate_tensor(tensor, radians):
    ...
    return rotated_tensor

# Get a random integer angle in [-r_degrees, r_degrees]
r_degrees = 5
angle = random.randint(-r_degrees, r_degrees)

# Convert angle from degrees to radians
ang_rad = angle * math.pi / 180


# test_tensor = preprocess_simple('path/to/file', (512,512))
test_tensor = torch.randn(1,3,512,512)


# Rotate input tensor somehow
output_tensor = rotate_tensor(test_tensor, ang_rad)


# Optionally use this to check rotated image
# deprocess_simple(output_tensor, 'rotated_image.jpg')

Some example outputs of what I'm trying to accomplish:

[Images: first and second examples of rotated images]

ProGamerGov asked Oct 04 '20 17:10


2 Answers

The grid generator and the sampler are submodules of the Spatial Transformer (JADERBERG, Max, et al.). These submodules are not trainable themselves; they let you apply either a learnable or a fixed spatial transformation. Here I take these two submodules and use them to rotate an image by theta with PyTorch's torch.nn.functional.affine_grid and torch.nn.functional.grid_sample, which implement the generator and the sampler, respectively:

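import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt

def get_rot_mat(theta):
    # 2x3 affine matrix: rotation by theta (radians), no translation
    theta = torch.tensor(theta)
    return torch.tensor([[torch.cos(theta), -torch.sin(theta), 0],
                         [torch.sin(theta), torch.cos(theta), 0]])


def rot_img(x, theta, dtype):
    # One rotation matrix per batch element: shape (B, 2, 3)
    rot_mat = get_rot_mat(theta)[None, ...].type(dtype).repeat(x.shape[0],1,1)
    # Sampling grid in normalized [-1, 1] coordinates, centered on the image
    grid = F.affine_grid(rot_mat, x.size(), align_corners=False).type(dtype)
    # Bilinearly interpolate x at the grid locations (zeros outside the image)
    x = F.grid_sample(x, grid, align_corners=False)
    return x


#Test:
dtype =  torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
#im should be a 4D tensor of shape B x C x H x W with type dtype, range [0,255].
#Placeholder so the snippet runs; replace it with a real image:
im = (torch.rand(1, 3, 256, 256) * 255).type(dtype)
plt.imshow(im.squeeze(0).permute(1,2,0).cpu()/255) #To plot it im should be 1 x C x H x W
plt.figure()
#Rotation by np.pi/2 with autograd support:
rotated_im = rot_img(im, np.pi/2, dtype) # Rotate image by 90 degrees.
plt.imshow(rotated_im.squeeze(0).permute(1,2,0).cpu()/255)
plt.show()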

In the example above, assume our image, im, is a dancing cat in a skirt: [image omitted]

rotated_im will then be the dancing cat in a skirt, rotated 90 degrees counterclockwise:

[image omitted]

And this is what we get if we call rot_img with theta equal to np.pi/4: [image omitted]
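Rotating by a non-multiple of 90 degrees leaves the corners of the output with no source pixels, and with grid_sample's default padding_mode="zeros" they are filled with 0. Since the question is about DeepDream-style use, where such corners can turn into artifacts, it may help to know that grid_sample also accepts padding_mode="border" and padding_mode="reflection". The sketch below compares them on a constant image. rotate_with_padding is a hypothetical helper that mirrors rot_img but takes its dtype/device from x and exposes the padding mode; it assumes a square image (no aspect-ratio correction).

```python
import math

import torch
import torch.nn.functional as F


def rotate_with_padding(x, radians, padding_mode="zeros"):
    # Hypothetical helper: same rotation matrix as rot_img, but the
    # padding mode for out-of-image locations is configurable.
    cos, sin = math.cos(radians), math.sin(radians)
    theta = torch.tensor([[cos, -sin, 0.0],
                          [sin, cos, 0.0]], dtype=x.dtype, device=x.device)
    theta = theta.unsqueeze(0).repeat(x.shape[0], 1, 1)  # (B, 2, 3)
    grid = F.affine_grid(theta, x.shape, align_corners=False)
    return F.grid_sample(x, grid, padding_mode=padding_mode,
                         align_corners=False)


# A constant image makes the padding visible: any value other than 1
# in the output came from outside the original image.
ones = torch.ones(1, 1, 32, 32)
for mode in ("zeros", "border", "reflection"):
    out = rotate_with_padding(ones, math.pi / 4, padding_mode=mode)
    # Top-left output pixel: it maps to a location outside the input
    print(mode, out[0, 0, 0, 0].item())
```

"border" repeats the edge pixels and "reflection" mirrors the image, so neither introduces a hard black corner; which one looks better for a given visualization is a judgment call.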

And the best part is that it's differentiable with respect to the input and supports autograd. Hooray!
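Two caveats worth knowing. First, get_rot_mat builds the matrix with torch.tensor(...), which copies the values and records no autograd graph, so gradients reach the image but not theta. Second, affine_grid normalizes x and y to [-1, 1] separately, so for a non-square image (H != W) the plain rotation matrix distorts the picture instead of rotating it rigidly. Below is a minimal sketch of the question's rotate_tensor(tensor, radians) stub that addresses both, under my own assumptions: the matrix is built with torch.stack on the input's dtype/device, and the off-diagonal terms are scaled by H/W and W/H (this correction is my addition, not part of the answer). It imports nothing beyond torch and the standard library.

```python
import math
import random

import torch
import torch.nn.functional as F


def rotate_tensor(tensor, radians):
    """Rotate a (B, C, H, W) tensor about its center by `radians`.

    Sketch only: uses torch ops throughout, so autograd can differentiate
    with respect to `tensor` and, if it is a tensor, `radians`.
    """
    b, _, h, w = tensor.shape
    angle = torch.as_tensor(radians, dtype=tensor.dtype, device=tensor.device)
    cos, sin = torch.cos(angle), torch.sin(angle)
    zero = torch.zeros_like(cos)
    # affine_grid normalizes x by W and y by H separately; scaling the
    # off-diagonal terms keeps the rotation rigid in pixel units.
    rot = torch.stack([
        torch.stack([cos, -sin * h / w, zero]),
        torch.stack([sin * w / h, cos, zero]),
    ])
    rot = rot.unsqueeze(0).repeat(b, 1, 1)  # one (2, 3) matrix per image
    grid = F.affine_grid(rot, tensor.shape, align_corners=False)
    # Bilinear sampling; locations outside the input are filled with 0
    return F.grid_sample(tensor, grid, mode="bilinear",
                         padding_mode="zeros", align_corners=False)


# Usage mirroring the question: random angle in [-5, 5] degrees
test_tensor = torch.randn(1, 3, 512, 512, requires_grad=True)
output_tensor = rotate_tensor(test_tensor, math.radians(random.uniform(-5, 5)))
output_tensor.sum().backward()

# Gradient w.r.t. the image, checked numerically in float64
img = torch.rand(1, 2, 6, 8, dtype=torch.float64, requires_grad=True)
img_ok = torch.autograd.gradcheck(lambda t: rotate_tensor(t, math.pi / 6), (img,))

# Gradient w.r.t. the angle, when the angle is passed as a tensor
angle = torch.tensor(0.3, requires_grad=True)
rotate_tensor(torch.rand(1, 3, 16, 16), angle).sum().backward()
print(test_tensor.grad is not None, img_ok, angle.grad is not None)
```

For square images the H/W scaling is 1 and the matrix reduces to the one in get_rot_mat.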

Gil Pinsky answered Sep 24 '22 10:09


There is a PyTorch function for rotations by multiples of 90 degrees:

import torch

x = torch.tensor([[0, 1],
                  [2, 3]])

x = torch.rot90(x, 1, [0, 1])
# tensor([[1, 3],
#         [0, 2]])

Here are the docs: https://pytorch.org/docs/stable/generated/torch.rot90.html
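torch.rot90 only rearranges elements (no interpolation), so it is exact and differentiable, but it is limited to multiples of 90 degrees and therefore does not cover the small random angles the question asks about. For a (B, C, H, W) batch, pass dims=(2, 3) so the rotation happens in the (H, W) plane; per the docs, k > 0 rotates from the first toward the second of those dims and k < 0 the other way. A quick sketch:

```python
import torch

# A batch of 2 images with 3 channels, each 4 x 4
x = torch.arange(2 * 3 * 4 * 4, dtype=torch.float32).reshape(2, 3, 4, 4)
x.requires_grad_()

# Rotate every image by 90 degrees in the (H, W) plane
y = torch.rot90(x, k=1, dims=(2, 3))

# rot90 only permutes elements, so d(sum y)/dx is all ones
y.sum().backward()
print(y.shape, torch.equal(x.grad, torch.ones_like(x)))
```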

Theodor Peifer answered Sep 25 '22 10:09