|
from itertools import product

import gdown
import gradio as gr
import numpy as np
import torch
import torchvision.transforms as transforms
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch

from common.evaluation import Evaluator
from model import chmnet
from model.base.geometry import Geometry
|
# Download the pre-trained CHM weights (PF-PASCAL benchmark, psi kernel).
# gdown caches the file and skips the download when the MD5 checksum matches.
md5 = "6b7b4d7bad7f89600fac340d6aa7708b"

gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5=md5,
)
|
# Demo configuration, mirroring the evaluation settings from the CHM repository.
args = {
    "alpha": [0.05, 0.1],
    "benchmark": "pfpascal",
    "bsz": 90,
    "datapath": "../Datasets_CHM",
    "img_size": 240,
    "ktype": "psi",
    "load": "pas_psi.pt",
    "thres": "img",
}

# Build the model, load the downloaded weights on CPU, initialize the evaluation
# thresholds and the geometry helper for the working resolution, and switch to
# inference mode.
model = chmnet.CHMNet(args["ktype"])
model.load_state_dict(torch.load(args["load"], map_location=torch.device("cpu")))
Evaluator.initialize(args["alpha"])
Geometry.initialize(img_size=args["img_size"])
model.eval()
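# Optional smoke test (a sketch, not part of the demo): the model maps a pair of
# normalized image batches to a correlation tensor, which Geometry.transfer_kps
# later decodes into keypoint transfers. The output shape depends on the CHM
# configuration, so this is illustrative rather than normative:
#
#   dummy = torch.randn(1, 3, args["img_size"], args["img_size"])
#   with torch.no_grad():
#       corr = model(dummy, dummy)
#   print(corr.shape)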
|
|
|
|
|
|
|
# Preprocessing for the network: resize, square center-crop, and ImageNet
# normalization (the statistics the backbone was trained with).
chm_transform = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
)

# The same geometric transform without tensor conversion, used only for display.
chm_transform_plot = transforms.Compose(
    [
        transforms.Resize(args["img_size"]),
        transforms.CenterCrop((args["img_size"], args["img_size"])),
    ]
)
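# For reference (sketch): chm_transform maps a PIL image to a normalized
# 3 x 240 x 240 float tensor, e.g.
#
#   from PIL import Image
#   t = chm_transform(Image.open("./examples/sample1.jpeg").convert("RGB"))
#   t.shape  # torch.Size([3, 240, 240])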
|
|
|
|
|
|
def run_chm(
    source_image,
    target_image,
    selected_points,
    number_src_points,
    chm_transform,
    display_transform,
):
    """Transfer the selected source keypoints to the target image and plot both."""
    # Preprocess both images into normalized 1 x 3 x H x W batches.
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    # Source keypoints in resized-image pixel coordinates, plus the point count.
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Forward pass: build the correlation tensor, then decode keypoint transfers.
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    # Source points are in [0, img_size) pixels; rescale to the displayed image.
    src_points_converted = []
    w, h = display_transform(source_image).size
    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append(
            [int(x * w / args["img_size"]), int(y * h / args["img_size"])]
        )
    src_points_converted = np.asarray(src_points_converted[:number_src_points])

    # Predicted points come back in normalized [-1, 1] coordinates; map them to
    # display pixels via (x + 1) / 2.
    tgt_points_converted = []
    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append(
            [int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)]
        )
    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    # Side-by-side plot: selected points on the left, transferred points on the right.
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(
        src_points_converted[:, 0],
        src_points_converted[:, 1],
        c="blue",
        edgecolors="white",
        s=50,
        label="Source points",
    )
    ax[0].set_title("Source Image with Selected Points")
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(
        tgt_points_converted[:, 0],
        tgt_points_converted[:, 1],
        c="red",
        edgecolors="white",
        s=50,
        label="Target points",
    )
    ax[1].set_title("Target Image with Corresponding Points")
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    # Number each pair so source and target points can be matched by eye.
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        ax[0].text(*src, str(i), color="white", bbox=dict(facecolor="black", alpha=0.5))
        ax[1].text(*tgt, str(i), color="black", bbox=dict(facecolor="white", alpha=0.7))

    # One distinct color per grid point (7 x 7 = 49 correspondences).
    cmap = plt.get_cmap("gist_rainbow", 49)

    # Connect each source point to its predicted match (sketch: one line per
    # correspondence, colored from the 49-entry colormap above).
    for i, (src, tgt) in enumerate(zip(src_points_converted, tgt_points_converted)):
        con = ConnectionPatch(
            xyA=tgt,
            xyB=src,
            coordsA="data",
            coordsB="data",
            axesA=ax[1],
            axesB=ax[0],
            color=cmap(i),
        )
        ax[1].add_artist(con)

    ax[0].legend(loc="lower right", bbox_to_anchor=(1, -0.075))
    ax[1].legend(loc="lower right", bbox_to_anchor=(1, -0.075))

    plt.tight_layout()
    plt.subplots_adjust(wspace=0.1, hspace=0.1)
    fig.suptitle("CHM Correspondences\nUsing $\\it{pas\\_psi.pt}$ Weights", fontsize=16)
    return fig
|
def generate_correspondences(
    source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100
):
    """Build a 7x7 grid of source keypoints and run CHM on the image pair."""
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T
    return run_chm(
        source_image,
        target_image,
        selected_points=new_points,
        number_src_points=49,
        chm_transform=chm_transform,
        display_transform=chm_transform_plot,
    )
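# Headless usage (sketch): the same pipeline can run without the UI, e.g. to
# save a figure to disk, using the sample pair from the Examples section below:
#
#   from PIL import Image
#   src = Image.open("./examples/sample1.jpeg").convert("RGB")
#   tgt = Image.open("./examples/sample2.jpeg").convert("RGB")
#   fig = generate_correspondences(src, tgt, min_x=17, max_x=223, min_y=17, max_y=223)
#   fig.savefig("correspondences.png")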
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(
        """
        # Correspondence Matching with Convolutional Hough Matching Networks

        Transfers keypoints from a 7x7 grid on the source image to the target image. Use the sliders to adjust the grid extent.

        [Original Paper](https://arxiv.org/abs/2103.16831) - [GitHub Page](https://github.com/juhongm999/chm)
        """
    )
|
    with gr.Row():
        image1 = gr.Image(
            height=240,
            width=240,
            type="pil",
            label="Source Image",
        )
        image2 = gr.Image(
            height=240,
            width=240,
            type="pil",
            label="Target Image",
        )
|
    # Grid extent sliders, in pixels of the 240x240 resized image.
    with gr.Row():
        min_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min X",
        )
        max_x = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max X",
        )
        min_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=15,
            label="Min Y",
        )
        max_y = gr.Slider(
            minimum=1,
            maximum=240,
            step=1,
            value=215,
            label="Max Y",
        )
|
    with gr.Row():
        output_plot = gr.Plot()
|
    gr.Examples(
        [
            ["./examples/sample1.jpeg", "./examples/sample2.jpeg", 17, 223, 17, 223],
            [
                "./examples/Red_Winged_Blackbird_0012_6015.jpg",
                "./examples/Red_Winged_Blackbird_0025_5342.jpg",
                17,
                223,
                17,
                223,
            ],
            [
                "./examples/Yellow_Headed_Blackbird_0026_8545.jpg",
                "./examples/Yellow_Headed_Blackbird_0020_8549.jpg",
                17,
                223,
                17,
                223,
            ],
        ],
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
    )
|
    run_btn = gr.Button("Run")

    run_btn.click(
        generate_correspondences,
        inputs=[image1, image2, min_x, max_x, min_y, max_y],
        outputs=output_plot,
    )
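# demo.launch() also accepts flags such as share=True (temporary public link)
# or server_port=...; the default local launch is used below.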
|
|
|
demo.launch() |
|
|