Image Batch Affine Transformation with Kornia

https://en.wikipedia.org/wiki/Affine_transformation

In [1]:
import matplotlib
import matplotlib.pyplot as plt
import kornia

from PIL import Image
import numpy as np
import torch, torchvision
from torchvision.transforms import functional_pil  
from torchvision.transforms.functional import _get_inverse_affine_matrix

from kornia.geometry.transform import warp_affine

Read and show the images

In [2]:
# Side length (in pixels) every image is resized to.
imsize = 224

# Load the three sample images and bring them to a common square size.
im1, im2, im3 = (
    Image.open(path).resize((imsize, imsize))
    for path in ('../data/dog.jpg', '../data/dog1.jpg', '../data/dog2.jpg')
)

# Show the resized images side by side, one per column.
fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(9, 3))
axes[0].imshow(im1)
axes[1].imshow(im2)
axes[2].imshow(im3)
Out[2]:
<matplotlib.image.AxesImage at 0x7f5b235c1c00>

Function to compute transform matrix

In [3]:
def generate_m(n=1, imgsize=224):
    """Sample ``n`` random affine matrices for square images.

    Each matrix is produced by torchvision's private helper
    ``_get_inverse_affine_matrix`` with randomly drawn parameters:
    angle in [-20, 20), scale in [0.7, 1.4), x-shear in [-20, 20),
    y-shear in [-10, 10) and no translation.

    Args:
        n: Number of matrices to generate.
        imgsize: Side length of the (square) image. The transformation
            center is placed at ``(imgsize // 2, imgsize // 2)``.

    Returns:
        A tensor of shape ``(n, 2, 3)``; for ``n <= 0`` the result is an
        empty tensor of shape ``(0, 2, 3)``.
    """
    # Use the function's own argument rather than a notebook-level global,
    # so the center follows the size the caller asked for.
    center = [imgsize // 2, imgsize // 2]

    rows = []
    for _ in range(n):
        # Each call yields the six coefficients of one 2x3 matrix.
        rows.append(_get_inverse_affine_matrix(center=center,
                                               angle=np.random.uniform(-20, 20),
                                               translate=(0, 0),
                                               scale=np.random.uniform(0.7, 1.4),
                                               shear=[np.random.uniform(-20, 20), np.random.uniform(-10, 10)]))

    # Reshape with an explicit batch size so an empty list yields (0, 2, 3)
    # instead of raising.
    return torch.tensor(rows).reshape(len(rows), 2, 3)

Apply two different transformations to original images

$$ \begin{aligned} data_1 &= data \cdot M_1 \\ data_2 &= data \cdot M_2 \end{aligned} $$

$M_1$ and $M_2$ are transformation matrices.

Get two transformation matrices

In [4]:
# Turn every image into a channel-first tensor with a leading batch axis,
# then join them into a single float batch.
data = torch.cat(
    [torch.tensor(np.array(img).transpose(2, 0, 1))[None] for img in (im1, im2, im3)],
    dim=0,
).float()

# Batch size: one transformation matrix is drawn per image.
N = data.size(0)

# Two independent sets of random affine matrices for the same batch.
M1 = generate_m(N)
M2 = generate_m(N)

Apply transformation, show affine transformed images and corresponding keypoints.

In [5]:
def _transform_points(points, M):
    """Apply a batch of 2x3 affine matrices to a batch of 2D points.

    Args:
        points: Array of shape (B, 2, K) holding x coordinates in row 0
            and y coordinates in row 1.
        M: Tensor of shape (B, 2, 3), one affine matrix per batch entry.

    Returns:
        Tensor of shape (B, 2, K) with the transformed coordinates.
    """
    pts = torch.tensor(points)
    # Append a row of ones to get homogeneous coordinates (B, 3, K).
    ones = torch.ones(pts.size(0), 1, pts.size(2))
    return torch.bmm(M, torch.cat((pts, ones), dim=1))


# Warp the whole batch with each set of matrices and convert the results
# (and the originals) to channel-last arrays for plotting.
data0 = data.numpy().transpose((0, 2, 3, 1))
data1 = warp_affine(data, M1, dsize=data.shape[-2:]).numpy().transpose((0, 2, 3, 1))
data2 = warp_affine(data, M2, dsize=data.shape[-2:]).numpy().transpose((0, 2, 3, 1))

# squeeze=False keeps `axes` two-dimensional even when N == 1.
fig, axes = plt.subplots(3, N, figsize=(N*4, 3*4), squeeze=False)
for ax in axes.flatten():
    ax.axis('off')

# Ten random keypoints per image, kept 10 px away from the borders,
# with one random color per keypoint index.
xys = np.random.uniform(10, imsize-10, size=(N, 2, 10)).astype(np.float32)
colors = np.random.uniform(0, 1, size=(10, 3))

# The keypoint transforms do not depend on the column index, so compute
# them once for the whole batch instead of once per loop iteration.
xys1 = _transform_points(xys, M1)
xys2 = _transform_points(xys, M2)

# Row 0: originals, row 1: warped with M1, row 2: warped with M2.
rows = ((data0, xys), (data1, xys1), (data2, xys2))
for i in range(N):
    for r, (images, points) in enumerate(rows):
        axes[r, i].imshow(images[i].astype(np.uint8))
        axes[r, i].scatter(points[i, 0, :], points[i, 1, :], color=colors, marker='x')