import matplotlib
import matplotlib.pyplot as plt
import kornia
from PIL import Image
import numpy as np
import torch, torchvision
from torchvision.transforms import functional_pil
from torchvision.transforms.functional import _get_inverse_affine_matrix
from kornia.geometry.transform import warp_affine
imsize = 224
# Force 3-channel RGB before resizing. The data cell later transposes each image
# array with transpose(2, 0, 1), which fails on a 2-D (grayscale) array. A CMYK
# JPEG would otherwise produce 4 channels that imshow reads as RGBA.
im1 = Image.open('../data/dog.jpg').convert('RGB').resize((imsize, imsize))
im2 = Image.open('../data/dog1.jpg').convert('RGB').resize((imsize, imsize))
im3 = Image.open('../data/dog2.jpg').convert('RGB').resize((imsize, imsize))

# Preview the three source images side by side.
fig, axes = plt.subplots(1, 3, figsize=(9, 3))
for ax, im in zip(axes, (im1, im2, im3)):
    ax.imshow(im)
def generate_m(n=1, imgsize=224):
    """Draw ``n`` random 2x3 affine matrices for square ``imgsize`` x ``imgsize`` images.

    Each matrix comes from torchvision's ``_get_inverse_affine_matrix`` with:
    - a random ``angle`` in [-20, 20];
    - a random ``scale`` in [0.7, 1.4];
    - random ``shear`` values in [-20, 20] (first) and [-10, 10] (second);
    - no translation;
    - ``center`` at ``(imgsize // 2, imgsize // 2)``.

    The global NumPy RNG is consumed in that order for every matrix.

    Args:
        n: number of matrices to draw.
        imgsize: side length (pixels) of the images the matrices are meant for.
            It determines the centre used for rotation/scale/shear.

    Returns:
        Tensor of shape ``(n, 2, 3)``. It is an empty ``(0, 2, 3)`` tensor when
        ``n <= 0``.
    """
    # Use the argument, not the notebook-global ``imsize``. The original ignored
    # ``imgsize`` entirely.
    center = [imgsize // 2, imgsize // 2]
    mats = []
    for _ in range(n):
        m = _get_inverse_affine_matrix(
            center=center,
            angle=np.random.uniform(-20, 20),
            translate=(0, 0),
            scale=np.random.uniform(0.7, 1.4),
            shear=[np.random.uniform(-20, 20), np.random.uniform(-10, 10)],
        )
        mats.append(torch.tensor(m).reshape(1, 2, 3))
    if not mats:
        # torch.cat raises on an empty sequence; return a well-shaped empty batch.
        return torch.empty(0, 2, 3)
    return torch.cat(mats, dim=0)
$M_1$ and $M_2$ are batches of random affine transformation matrices, one $2 \times 3$ matrix per image.
# Turn each PIL image (H, W, C) into a (C, H, W) tensor and stack them into one
# float batch of shape (N, C, H, W).
data = torch.stack(
    [torch.tensor(np.array(img).transpose(2, 0, 1)) for img in (im1, im2, im3)]
).float()
N = data.size(0)

# Two independent sets of random affine matrices. M1 is drawn before M2, so the
# global NumPy RNG is consumed in the same order as before.
M1 = generate_m(N)
M2 = generate_m(N)

# Warp the batch with each set at the input resolution. Convert the results back
# to (N, H, W, C) NumPy arrays for plotting.
out_hw = data.shape[-2:]
data1, data2 = (
    warp_affine(data, m, dsize=out_hw).numpy().transpose((0, 2, 3, 1))
    for m in (M1, M2)
)
# Three rows of axes (original, M1-warped, M2-warped) by N columns (one per image).
# squeeze=False keeps ``axes`` 2-D even when N == 1. Otherwise matplotlib returns
# a 1-D array and the ``axes[row, i]`` indexing in the plotting loop fails.
fig, axes = plt.subplots(3, N, figsize=(N * 4, 3 * 4), squeeze=False)
for ax in axes.flat:
    ax.axis('off')

# Random keypoints: for each image, a row of x and a row of y coordinates
# (shape (N, 2, n_points)), kept 10 px away from the borders. Each keypoint gets
# one random RGB colour, shared across all images so it can be tracked between rows.
n_points = 10
xys = np.random.uniform(10, imsize - 10, size=(N, 2, n_points)).astype(np.float32)
colors = np.random.uniform(0, 1, size=(n_points, 3))
def _project_points(mats, pts):
    """Apply a batch of 2x3 affine matrices to a batch of 2-D point sets.

    Args:
        mats: tensor of shape (B, 2, 3); each matrix is [A | t].
        pts: array/tensor of shape (B, 2, K); row 0 holds x and row 1 holds y.

    Returns:
        Tensor of shape (B, 2, K) holding ``A @ p + t`` for every point. This is
        the same as multiplying the homogeneous point (x, y, 1) by [[A | t], [0, 0, 1]]
        and dropping the last row.
    """
    # Cast the points to the matrices' dtype so bmm never sees mixed dtypes.
    pts = torch.as_tensor(pts, dtype=mats.dtype)
    return torch.bmm(mats[:, :, :2], pts) + mats[:, :, 2:]


# These do not depend on the column index; compute them once instead of once
# per iteration.
data_hwc = data.numpy().transpose((0, 2, 3, 1))
xys1 = _project_points(M1, xys)
xys2 = _project_points(M2, xys)

# Row 0: original images with the random keypoints.
# Rows 1 and 2: warped images with the keypoints mapped through the same matrices.
for i in range(N):
    for row, (imgs, pts) in enumerate(((data_hwc, xys), (data1, xys1), (data2, xys2))):
        axes[row, i].imshow(imgs[i].astype(np.uint8))
        axes[row, i].scatter(pts[i, 0, :], pts[i, 1, :], color=colors, marker='x')