Initial Commit
- README.md +3 -27
- app.py +149 -0
- common/__pycache__/evaluation.cpython-38.pyc +0 -0
- common/__pycache__/logger.cpython-38.pyc +0 -0
- common/evaluation.py +32 -0
- common/logger.py +117 -0
- data/__pycache__/dataset.cpython-38.pyc +0 -0
- data/__pycache__/download.cpython-38.pyc +0 -0
- data/__pycache__/pfpascal.cpython-38.pyc +0 -0
- data/__pycache__/pfwillow.cpython-38.pyc +0 -0
- data/__pycache__/spair.cpython-38.pyc +0 -0
- data/dataset.py +140 -0
- data/download.py +91 -0
- data/pfpascal.py +108 -0
- data/pfwillow.py +56 -0
- data/spair.py +105 -0
- model/__pycache__/chmlearner.cpython-38.pyc +0 -0
- model/__pycache__/chmnet.cpython-38.pyc +0 -0
- model/base/__pycache__/backbone.cpython-38.pyc +0 -0
- model/base/__pycache__/chm.cpython-38.pyc +0 -0
- model/base/__pycache__/chm_kernel.cpython-38.pyc +0 -0
- model/base/__pycache__/correlation.cpython-38.pyc +0 -0
- model/base/__pycache__/geometry.cpython-38.pyc +0 -0
- model/base/backbone.py +136 -0
- model/base/chm.py +190 -0
- model/base/chm_kernel.py +66 -0
- model/base/correlation.py +68 -0
- model/base/geometry.py +133 -0
- model/chmlearner.py +52 -0
- model/chmnet.py +42 -0
- requirements.txt +10 -0
README.md
CHANGED
@@ -1,37 +1,13 @@
```diff
 ---
 title: ConvolutionalHoughMatchingNetworks
 emoji: 📚
-colorFrom:
+colorFrom: red
 colorTo: yellow
 sdk: gradio
 app_file: app.py
 pinned: false
 ---
 
-#
-
-Display title for the Space
-
-`emoji`: _string_
-Space emoji (emoji-only character allowed)
-
-`colorFrom`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`colorTo`: _string_
-Color for Thumbnail gradient (red, yellow, green, blue, indigo, purple, pink, gray)
-
-`sdk`: _string_
-Can be either `gradio` or `streamlit`
-
-`sdk_version` : _string_
-Only applicable for `streamlit` SDK.
-See [doc](https://hf.co/docs/hub/spaces) for more info on supported versions.
-
-`app_file`: _string_
-Path to your main application file (which contains either `gradio` or `streamlit` Python code).
-Path is relative to the root of the repository.
-
-`pinned`: _boolean_
-Whether the Space stays on top of your list.
+# Convolutional Hough Matching Networks
+
+A demo for Convolutional Hough Matching Networks. [[Paper](https://arxiv.org/abs/2109.05221)] [[Official Github Repo](https://github.com/juhongm999/chm.git)]
```
app.py
ADDED
@@ -0,0 +1,149 @@
```python
from torch.utils.data import DataLoader
import torch
from model.base.geometry import Geometry
from common.evaluation import Evaluator
from common.logger import AverageMeter
from common.logger import Logger
from data import download
from model import chmnet
from matplotlib import pyplot as plt
from matplotlib.patches import ConnectionPatch
from PIL import Image
import matplotlib                # used by matplotlib.cm.get_cmap below
from itertools import product    # used to build the 7x7 key-point grid below
import numpy as np
import os
import torchvision
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
import torchvision.models as models
import torch.nn as nn
import torch.nn.functional as F
import random
import gradio as gr

# Downloading the model
torchvision.datasets.utils.download_file_from_google_drive('1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6', '.', 'pas_psi.pt')

# Model initialization
args = dict({
    'alpha': [0.05, 0.1],
    'benchmark': 'pfpascal',
    'bsz': 90,
    'datapath': '../Datasets_CHM',
    'img_size': 240,
    'ktype': 'psi',
    'load': 'pas_psi.pt',
    'thres': 'img'
})

model = chmnet.CHMNet(args['ktype'])
model.load_state_dict(torch.load(args['load'], map_location=torch.device('cpu')))
Evaluator.initialize(args['alpha'])
Geometry.initialize(img_size=args['img_size'])
model.eval()

# Transforms
chm_transform = transforms.Compose(
    [transforms.Resize(args['img_size']),
     transforms.CenterCrop((args['img_size'], args['img_size'])),
     transforms.ToTensor(),
     transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])])

chm_transform_plot = transforms.Compose(
    [transforms.Resize(args['img_size']),
     transforms.CenterCrop((args['img_size'], args['img_size']))])

# A helper function
to_np = lambda x: x.data.to('cpu').numpy()

# Colors for plotting
cmap = matplotlib.cm.get_cmap('Spectral')
rgba = cmap(0.5)
colors = []
for k in range(49):
    colors.append(cmap(k / 49.0))


# CHM model
def run_chm(source_image, target_image, selected_points, number_src_points, chm_transform, display_transform):
    # Convert to tensor
    src_img_tnsr = chm_transform(source_image).unsqueeze(0)
    tgt_img_tnsr = chm_transform(target_image).unsqueeze(0)

    # Selected_points = selected_points.T
    keypoints = torch.tensor(selected_points).unsqueeze(0)
    n_pts = torch.tensor(np.asarray([number_src_points]))

    # Run CHM ------------------------------------------------------------------------
    with torch.no_grad():
        corr_matrix = model(src_img_tnsr, tgt_img_tnsr)
        prd_kps = Geometry.transfer_kps(corr_matrix, keypoints, n_pts, normalized=False)

    # Visualization
    src_points = keypoints[0].squeeze(0).squeeze(0).numpy()
    tgt_points = prd_kps[0].squeeze(0).squeeze(0).cpu().numpy()

    src_points_converted = []
    w, h = display_transform(source_image).size

    for x, y in zip(src_points[0], src_points[1]):
        src_points_converted.append([int(x * w / args['img_size']), int(y * h / args['img_size'])])

    src_points_converted = np.asarray(src_points_converted[:number_src_points])
    tgt_points_converted = []

    w, h = display_transform(target_image).size
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_points_converted.append([int(((x + 1) / 2.0) * w), int(((y + 1) / 2.0) * h)])

    tgt_points_converted = np.asarray(tgt_points_converted[:number_src_points])

    tgt_grid = []
    for x, y in zip(tgt_points[0], tgt_points[1]):
        tgt_grid.append([int(((x + 1) / 2.0) * 7), int(((y + 1) / 2.0) * 7)])

    # Plot
    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 8))

    ax[0].imshow(display_transform(source_image))
    ax[0].scatter(src_points_converted[:, 0], src_points_converted[:, 1], c=colors[:number_src_points])
    ax[0].set_title('Source')
    ax[0].set_xticks([])
    ax[0].set_yticks([])

    ax[1].imshow(display_transform(target_image))
    ax[1].scatter(tgt_points_converted[:, 0], tgt_points_converted[:, 1], c=colors[:number_src_points])
    ax[1].set_title('Target')
    ax[1].set_xticks([])
    ax[1].set_yticks([])

    for TL in range(49):
        ax[0].text(x=src_points_converted[TL][0], y=src_points_converted[TL][1], s=str(TL), fontdict=dict(color='red', size=10))

    for TL in range(49):
        ax[1].text(x=tgt_points_converted[TL][0], y=tgt_points_converted[TL][1], s=f'{str(TL)}', fontdict=dict(color='orange', size=8))

    plt.tight_layout()
    fig.suptitle('CHM Correspondences\nUsing $\it{pas\_psi.pt}$ Weights', fontsize=16)
    return fig


# Wrapper
def generate_correspondences(source_image, target_image, min_x=1, max_x=100, min_y=1, max_y=100):
    A = np.linspace(min_x, max_x, 7)
    B = np.linspace(min_y, max_y, 7)
    point_list = list(product(A, B))
    new_points = np.asarray(point_list, dtype=np.float64).T
    return run_chm(source_image, target_image, selected_points=new_points, number_src_points=49, chm_transform=chm_transform, display_transform=chm_transform_plot)


# Gradio app
iface = gr.Interface(fn=generate_correspondences,
                     inputs=[gr.inputs.Image(shape=(240, 240), type='pil'),
                             gr.inputs.Image(shape=(240, 240), type='pil'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinX'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxX'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=15, label='MinY'),
                             gr.inputs.Slider(minimum=1, maximum=240, step=1, default=215, label='MaxY')], outputs="plot")
iface.launch()
```
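The predicted key-points returned by `Geometry.transfer_kps` are in normalized [-1, 1] coordinates, while the user-selected source points are in model-input pixels; `run_chm` converts both to display pixels before plotting. A minimal standalone sketch of those two conversions, with made-up numbers:

```python
img_size = 240                     # model input resolution, as in args['img_size']
disp_w, disp_h = 240, 240          # size of the displayed (centre-cropped) image

# Source point: model-input pixels -> display pixels
src_x, src_y = 120.0, 60.0
src_disp = (int(src_x * disp_w / img_size), int(src_y * disp_h / img_size))

# Predicted target point: normalized [-1, 1] -> display pixels
tgt_x, tgt_y = 0.25, -0.5
tgt_disp = (int((tgt_x + 1) / 2.0 * disp_w), int((tgt_y + 1) / 2.0 * disp_h))

print(src_disp, tgt_disp)          # (120, 60) (150, 60)
```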
common/__pycache__/evaluation.cpython-38.pyc
ADDED
Binary file (1.3 kB).
common/__pycache__/logger.cpython-38.pyc
ADDED
Binary file (4.23 kB).
common/evaluation.py
ADDED
@@ -0,0 +1,32 @@
```python
r""" Evaluates CHMNet with PCK """

import torch


class Evaluator:
    r""" Computes evaluation metrics of PCK """
    @classmethod
    def initialize(cls, alpha):
        cls.alpha = torch.tensor(alpha).unsqueeze(1)

    @classmethod
    def evaluate(cls, prd_kps, batch):
        r""" Compute percentage of correct key-points (PCK) with multiple alpha {0.05, 0.1, 0.15} """

        pcks = []
        for idx, (pk, tk) in enumerate(zip(prd_kps, batch['trg_kps'])):
            pckthres = batch['pckthres'][idx]
            npt = batch['n_pts'][idx]
            prd_kps = pk[:, :npt]
            trg_kps = tk[:, :npt]

            l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5).unsqueeze(0).repeat(len(cls.alpha), 1)
            thres = pckthres.expand_as(l2dist).float() * cls.alpha
            pck = torch.le(l2dist, thres).sum(dim=1) / float(npt)
            if len(pck) == 1: pck = pck[0]
            pcks.append(pck)

        eval_result = {'pck': pcks}

        return eval_result
```
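A self-contained numeric sketch of the PCK computation above, using made-up key-points, the `'img'` threshold (the larger image side, here 240), and the two alpha values the demo uses:

```python
import torch

alpha = torch.tensor([0.05, 0.1]).unsqueeze(1)                 # (2, 1), as in Evaluator.initialize
prd_kps = torch.tensor([[10., 50., 100.], [20., 60., 110.]])   # predicted (x; y) for 3 points
trg_kps = torch.tensor([[12., 70., 101.], [20., 60., 111.]])   # ground-truth key-points
pckthres = torch.tensor(240.)                                  # 'img' threshold: max(H, W)

l2dist = (prd_kps - trg_kps).pow(2).sum(dim=0).pow(0.5)        # per-point distances: 2.0, 20.0, ~1.4
l2dist = l2dist.unsqueeze(0).repeat(len(alpha), 1)             # (2, 3)
thres = pckthres.expand_as(l2dist) * alpha                     # 12 px and 24 px tolerances
pck = torch.le(l2dist, thres).sum(dim=1) / 3.0                 # fraction of correct points per alpha
print(pck)                                                     # tensor([0.6667, 1.0000])
```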
common/logger.py
ADDED
@@ -0,0 +1,117 @@
```python
r""" Logging """

import datetime
import logging
import os

from tensorboardX import SummaryWriter
import torch


class Logger:
    r""" Writes results of training/testing """
    @classmethod
    def initialize(cls, args, training):
        logtime = datetime.datetime.now().__format__('_%m%d_%H%M%S')
        logpath = args.logpath if training else '_TEST_' + args.load.split('/')[-1].split('.')[0] + logtime
        if logpath == '': logpath = logtime

        cls.logpath = os.path.join('logs', logpath + '.log')
        cls.benchmark = args.benchmark
        os.makedirs(cls.logpath)

        logging.basicConfig(filemode='w',
                            filename=os.path.join(cls.logpath, 'log.txt'),
                            level=logging.INFO,
                            format='%(message)s',
                            datefmt='%m-%d %H:%M:%S')

        # Console log config
        console = logging.StreamHandler()
        console.setLevel(logging.INFO)
        formatter = logging.Formatter('%(message)s')
        console.setFormatter(formatter)
        logging.getLogger('').addHandler(console)

        # Tensorboard writer
        cls.tbd_writer = SummaryWriter(os.path.join(cls.logpath, 'tbd/runs'))

        # Log arguments
        if training:
            logging.info(':======== Convolutional Hough Matching Networks =========')
            for arg_key in args.__dict__:
                logging.info('| %20s: %-24s' % (arg_key, str(args.__dict__[arg_key])))
            logging.info(':========================================================\n')

    @classmethod
    def info(cls, msg):
        r""" Writes message to .txt """
        logging.info(msg)

    @classmethod
    def save_model(cls, model, epoch, val_pck):
        torch.save(model.state_dict(), os.path.join(cls.logpath, 'pck_best_model.pt'))
        cls.info('Model saved @%d w/ val. PCK: %5.2f.\n' % (epoch, val_pck))


class AverageMeter:
    r""" Stores loss, evaluation results, selected layers """
    def __init__(self, benchamrk):
        r""" Constructor of AverageMeter """
        self.buffer_keys = ['pck']
        self.buffer = {}
        for key in self.buffer_keys:
            self.buffer[key] = []

        self.loss_buffer = []

    def update(self, eval_result, loss=None):
        for key in self.buffer_keys:
            self.buffer[key] += eval_result[key]

        if loss is not None:
            self.loss_buffer.append(loss)

    def write_result(self, split, epoch):
        msg = '\n*** %s ' % split
        msg += '[@Epoch %02d] ' % epoch

        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += '%s: %6.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        msg += '***\n'
        Logger.info(msg)

    def write_process(self, batch_idx, datalen, epoch):
        msg = '[Epoch: %02d] ' % epoch
        msg += '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)
        if len(self.loss_buffer) > 0:
            msg += 'Loss: %5.2f ' % self.loss_buffer[-1]
            msg += 'Avg Loss: %5.5f ' % (sum(self.loss_buffer) / len(self.loss_buffer))

        for key in self.buffer_keys:
            msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]) * 100)
        Logger.info(msg)

    def write_test_process(self, batch_idx, datalen):
        msg = '[Batch: %04d/%04d] ' % (batch_idx+1, datalen)

        for key in self.buffer_keys:
            if key == 'pck':
                pcks = torch.stack(self.buffer[key]).mean(dim=0) * 100
                val = ''
                for p in pcks:
                    val += '%5.2f ' % p.item()
                msg += 'Avg %s: %s ' % (key.upper(), val)
            else:
                msg += 'Avg %s: %5.2f ' % (key.upper(), sum(self.buffer[key]) / len(self.buffer[key]))
        Logger.info(msg)

    def get_test_result(self):
        result = {}
        for key in self.buffer_keys:
            result[key] = torch.stack(self.buffer[key]).mean(dim=0) * 100

        return result
```
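A small usage sketch of `AverageMeter` at test time, assuming the repository modules and `tensorboardX` are importable; the numbers are made up and mirror what `Evaluator.evaluate` returns per image pair:

```python
import torch

from common.logger import AverageMeter

meter = AverageMeter('pfpascal')                       # the benchmark argument is stored nowhere internally
meter.update({'pck': [torch.tensor([0.60, 0.80])]})    # PCK at alpha = 0.05 / 0.1 for one pair
meter.update({'pck': [torch.tensor([0.70, 0.90])]})    # ... and for a second pair
print(meter.get_test_result())                         # {'pck': tensor([65., 85.])}
```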
data/__pycache__/dataset.cpython-38.pyc
ADDED
Binary file (3.95 kB).
data/__pycache__/download.cpython-38.pyc
ADDED
Binary file (2.56 kB).
data/__pycache__/pfpascal.cpython-38.pyc
ADDED
Binary file (3.91 kB).
data/__pycache__/pfwillow.cpython-38.pyc
ADDED
Binary file (2.85 kB).
data/__pycache__/spair.cpython-38.pyc
ADDED
Binary file (5.51 kB).
data/dataset.py
ADDED
@@ -0,0 +1,140 @@
```python
r""" Superclass for semantic correspondence datasets """

import os

from torch.utils.data import Dataset
from torchvision import transforms
from PIL import Image
import torch

from model.base.geometry import Geometry


class CorrespondenceDataset(Dataset):
    r""" Parent class of PFPascal, PFWillow, and SPair """
    def __init__(self, benchmark, datapath, thres, split):
        r""" CorrespondenceDataset constructor """
        super(CorrespondenceDataset, self).__init__()

        # {Directory name, Layout path, Image path, Annotation path, PCK threshold}
        self.metadata = {
            'pfwillow': ('PF-WILLOW',
                         'test_pairs.csv',
                         '',
                         '',
                         'bbox'),
            'pfpascal': ('PF-PASCAL',
                         '_pairs.csv',
                         'JPEGImages',
                         'Annotations',
                         'img'),
            'spair': ('SPair-71k',
                      'Layout/large',
                      'JPEGImages',
                      'PairAnnotation',
                      'bbox')
        }

        # Directory path for train, val, or test splits
        base_path = os.path.join(os.path.abspath(datapath), self.metadata[benchmark][0])
        if benchmark == 'pfpascal':
            self.spt_path = os.path.join(base_path, split+'_pairs.csv')
        elif benchmark == 'spair':
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1], split+'.txt')
        else:
            self.spt_path = os.path.join(base_path, self.metadata[benchmark][1])

        # Directory path for images
        self.img_path = os.path.join(base_path, self.metadata[benchmark][2])

        # Directory path for annotations
        if benchmark == 'spair':
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3], split)
        else:
            self.ann_path = os.path.join(base_path, self.metadata[benchmark][3])

        # Miscellaneous
        self.max_pts = 40
        self.split = split
        self.img_size = Geometry.img_size
        self.benchmark = benchmark
        self.range_ts = torch.arange(self.max_pts)
        self.thres = self.metadata[benchmark][4] if thres == 'auto' else thres
        self.transform = transforms.Compose([transforms.Resize((self.img_size, self.img_size)),
                                             transforms.ToTensor(),
                                             transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                                  std=[0.229, 0.224, 0.225])])

        # To get initialized in subclass constructors
        self.train_data = []
        self.src_imnames = []
        self.trg_imnames = []
        self.cls = []
        self.cls_ids = []
        self.src_kps = []
        self.trg_kps = []

    def __len__(self):
        r""" Returns the number of pairs """
        return len(self.train_data)

    def __getitem__(self, idx):
        r""" Constructs and return a batch """

        # Image name
        batch = dict()
        batch['src_imname'] = self.src_imnames[idx]
        batch['trg_imname'] = self.trg_imnames[idx]

        # Object category
        batch['category_id'] = self.cls_ids[idx]
        batch['category'] = self.cls[batch['category_id']]

        # Image as numpy (original width, original height)
        src_pil = self.get_image(self.src_imnames, idx)
        trg_pil = self.get_image(self.trg_imnames, idx)
        batch['src_imsize'] = src_pil.size
        batch['trg_imsize'] = trg_pil.size

        # Image as tensor
        batch['src_img'] = self.transform(src_pil)
        batch['trg_img'] = self.transform(trg_pil)

        # Key-points (re-scaled)
        batch['src_kps'], num_pts = self.get_points(self.src_kps, idx, src_pil.size)
        batch['trg_kps'], _ = self.get_points(self.trg_kps, idx, trg_pil.size)
        batch['n_pts'] = torch.tensor(num_pts)

        # Total number of pairs in training split
        batch['datalen'] = len(self.train_data)

        return batch

    def get_image(self, imnames, idx):
        r""" Reads PIL image from path """
        path = os.path.join(self.img_path, imnames[idx])
        return Image.open(path).convert('RGB')

    def get_pckthres(self, batch, imsize):
        r""" Computes PCK threshold """
        if self.thres == 'bbox':
            bbox = batch['trg_bbox'].clone()
            bbox_w = (bbox[2] - bbox[0])
            bbox_h = (bbox[3] - bbox[1])
            pckthres = torch.max(bbox_w, bbox_h)
        elif self.thres == 'img':
            imsize_t = batch['trg_img'].size()
            pckthres = torch.tensor(max(imsize_t[1], imsize_t[2]))
        else:
            raise Exception('Invalid pck threshold type: %s' % self.thres)
        return pckthres.float()

    def get_points(self, pts_list, idx, org_imsize):
        r""" Returns key-points of an image """
        xy, n_pts = pts_list[idx].size()
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        x_crds = pts_list[idx][0] * (self.img_size / org_imsize[0])
        y_crds = pts_list[idx][1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts
```
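A standalone sketch of the key-point rescaling and padding done in `get_points`, with hypothetical values: a 320x240 image with 3 annotated points, rescaled to the 240x240 network input and padded to `max_pts=40`:

```python
import torch

img_size, max_pts = 240, 40
org_w, org_h = 320, 240
kps = torch.tensor([[32., 160., 288.],     # x coordinates of 3 annotated points
                    [24., 120., 216.]])    # y coordinates

xy, n_pts = kps.size()                              # (2, 3)
pad_pts = torch.zeros((xy, max_pts - n_pts)) - 2    # unused slots are filled with -2
x_crds = kps[0] * (img_size / org_w)                # rescale x to the 240x240 input resolution
y_crds = kps[1] * (img_size / org_h)                # rescale y likewise
padded = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)
print(padded.shape, n_pts)                          # torch.Size([2, 40]) 3
```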
data/download.py
ADDED
@@ -0,0 +1,91 @@
```python
r""" Functions to download semantic correspondence datasets """

import tarfile
import os

import requests

from . import pfpascal
from . import pfwillow
from . import spair


def load_dataset(benchmark, datapath, thres, split='test'):
    r""" Instantiate a correspondence dataset """
    correspondence_benchmark = {
        'spair': spair.SPairDataset,
        'pfpascal': pfpascal.PFPascalDataset,
        'pfwillow': pfwillow.PFWillowDataset
    }

    dataset = correspondence_benchmark.get(benchmark)
    if dataset is None:
        raise Exception('Invalid benchmark dataset %s.' % benchmark)

    return dataset(benchmark, datapath, thres, split)


def download_from_google(token_id, filename):
    r""" Download desired filename from Google drive """

    print('Downloading %s ...' % os.path.basename(filename))

    url = 'https://docs.google.com/uc?export=download'
    destination = filename + '.tar.gz'
    session = requests.Session()

    response = session.get(url, params={'id': token_id}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': token_id, 'confirm': token}
        response = session.get(url, params=params, stream=True)
    save_response_content(response, destination)
    file = tarfile.open(destination, 'r:gz')

    print("Extracting %s ..." % destination)
    file.extractall(filename)
    file.close()

    os.remove(destination)
    os.rename(filename, filename + '_tmp')
    os.rename(os.path.join(filename + '_tmp', os.path.basename(filename)), filename)
    os.rmdir(filename + '_tmp')


def get_confirm_token(response):
    r""" Retrieves confirm token """
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def save_response_content(response, destination):
    r""" Saves the response to the destination """
    chunk_size = 32768

    with open(destination, "wb") as file:
        for chunk in response.iter_content(chunk_size):
            if chunk:
                file.write(chunk)


def download_dataset(datapath, benchmark):
    r""" Downloads semantic correspondence benchmark dataset from Google drive """
    if not os.path.isdir(datapath):
        os.mkdir(datapath)

    file_data = {
        # 'spair': ('1s73NVEFPro260H1tXxCh1ain7oApR8of', 'SPair-71k') old version
        'spair': ('1KSvB0k2zXA06ojWNvFjBv0Ake426Y76k', 'SPair-71k'),
        'pfpascal': ('1OOwpGzJnTsFXYh-YffMQ9XKM_Kl_zdzg', 'PF-PASCAL'),
        'pfwillow': ('1tDP0y8RO5s45L-vqnortRaieiWENQco_', 'PF-WILLOW')
    }

    file_id, filename = file_data[benchmark]
    abs_filepath = os.path.join(datapath, filename)

    if not os.path.isdir(abs_filepath):
        download_from_google(file_id, abs_filepath)
```
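A minimal usage sketch tying these helpers to the dataset classes; it assumes the repository root is on `PYTHONPATH`, that `Geometry` is initialized first as in `app.py`, and that the Google Drive archives are still reachable (the path is a placeholder):

```python
from torch.utils.data import DataLoader

from model.base.geometry import Geometry
from data import download

Geometry.initialize(img_size=240)                # dataset classes read Geometry.img_size

datapath = '../Datasets_CHM'                     # placeholder download/extraction directory
download.download_dataset(datapath, 'pfpascal')  # fetch and extract PF-PASCAL if missing
test_ds = download.load_dataset('pfpascal', datapath, thres='img', split='test')

loader = DataLoader(test_ds, batch_size=4, shuffle=False)
batch = next(iter(loader))
print(batch['src_img'].shape, batch['n_pts'])    # e.g. torch.Size([4, 3, 240, 240]) and per-pair point counts
```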
data/pfpascal.py
ADDED
@@ -0,0 +1,108 @@
```python
r""" PF-PASCAL dataset """

import os

import scipy.io as sio
import pandas as pd
import numpy as np
import torch

from .dataset import CorrespondenceDataset


class PFPascalDataset(CorrespondenceDataset):

    def __init__(self, benchmark, datapath, thres, split):
        r""" PF-PASCAL dataset constructor """
        super(PFPascalDataset, self).__init__(benchmark, datapath, thres, split)

        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])
        self.cls = ['aeroplane', 'bicycle', 'bird', 'boat', 'bottle',
                    'bus', 'car', 'cat', 'chair', 'cow',
                    'diningtable', 'dog', 'horse', 'motorbike', 'person',
                    'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor']
        self.cls_ids = self.train_data.iloc[:, 2].values.astype('int') - 1

        if split == 'trn':
            self.flip = self.train_data.iloc[:, 3].values.astype('int')
        self.src_kps = []
        self.trg_kps = []
        self.src_bbox = []
        self.trg_bbox = []
        for src_imname, trg_imname, cls in zip(self.src_imnames, self.trg_imnames, self.cls_ids):
            src_anns = os.path.join(self.ann_path, self.cls[cls],
                                    os.path.basename(src_imname))[:-4] + '.mat'
            trg_anns = os.path.join(self.ann_path, self.cls[cls],
                                    os.path.basename(trg_imname))[:-4] + '.mat'

            src_kp = torch.tensor(read_mat(src_anns, 'kps')).float()
            trg_kp = torch.tensor(read_mat(trg_anns, 'kps')).float()
            src_box = torch.tensor(read_mat(src_anns, 'bbox')[0].astype(float))
            trg_box = torch.tensor(read_mat(trg_anns, 'bbox')[0].astype(float))

            src_kps = []
            trg_kps = []
            for src_kk, trg_kk in zip(src_kp, trg_kp):
                if len(torch.isnan(src_kk).nonzero()) != 0 or \
                        len(torch.isnan(trg_kk).nonzero()) != 0:
                    continue
                else:
                    src_kps.append(src_kk)
                    trg_kps.append(trg_kk)
            self.src_kps.append(torch.stack(src_kps).t())
            self.trg_kps.append(torch.stack(trg_kps).t())
            self.src_bbox.append(src_box)
            self.trg_bbox.append(trg_box)

        self.src_imnames = list(map(lambda x: os.path.basename(x), self.src_imnames))
        self.trg_imnames = list(map(lambda x: os.path.basename(x), self.trg_imnames))

    def __getitem__(self, idx):
        r""" Constructs and returns a batch for PF-PASCAL dataset """
        batch = super(PFPascalDataset, self).__getitem__(idx)

        # Object bounding-box (resized following self.img_size)
        batch['src_bbox'] = self.get_bbox(self.src_bbox, idx, batch['src_imsize'])
        batch['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, batch['trg_imsize'])
        batch['pckthres'] = self.get_pckthres(batch, batch['trg_imsize'])

        # Horizontal flipping key-points during training
        if self.split == 'trn' and self.flip[idx]:
            self.horizontal_flip(batch)
            batch['flip'] = 1
        else:
            batch['flip'] = 0

        return batch

    def get_bbox(self, bbox_list, idx, imsize):
        r""" Returns object bounding-box """
        bbox = bbox_list[idx].clone()
        bbox[0::2] *= (self.img_size / imsize[0])
        bbox[1::2] *= (self.img_size / imsize[1])
        return bbox

    def horizontal_flip(self, batch):
        tmp = batch['src_bbox'][0].clone()
        batch['src_bbox'][0] = batch['src_img'].size(2) - batch['src_bbox'][2]
        batch['src_bbox'][2] = batch['src_img'].size(2) - tmp

        tmp = batch['trg_bbox'][0].clone()
        batch['trg_bbox'][0] = batch['trg_img'].size(2) - batch['trg_bbox'][2]
        batch['trg_bbox'][2] = batch['trg_img'].size(2) - tmp

        batch['src_kps'][0][:batch['n_pts']] = batch['src_img'].size(2) - batch['src_kps'][0][:batch['n_pts']]
        batch['trg_kps'][0][:batch['n_pts']] = batch['trg_img'].size(2) - batch['trg_kps'][0][:batch['n_pts']]

        batch['src_img'] = torch.flip(batch['src_img'], dims=(2,))
        batch['trg_img'] = torch.flip(batch['trg_img'], dims=(2,))


def read_mat(path, obj_name):
    r""" Reads specified objects from Matlab data file. (.mat) """
    mat_contents = sio.loadmat(path)
    mat_obj = mat_contents[obj_name]

    return mat_obj
```
data/pfwillow.py
ADDED
@@ -0,0 +1,56 @@
```python
r""" PF-WILLOW dataset """

import os

import pandas as pd
import numpy as np
import torch

from .dataset import CorrespondenceDataset


class PFWillowDataset(CorrespondenceDataset):

    def __init__(self, benchmark, datapath, thres, split):
        r""" PF-WILLOW dataset constructor """
        super(PFWillowDataset, self).__init__(benchmark, datapath, thres, split)

        self.train_data = pd.read_csv(self.spt_path)
        self.src_imnames = np.array(self.train_data.iloc[:, 0])
        self.trg_imnames = np.array(self.train_data.iloc[:, 1])
        self.src_kps = self.train_data.iloc[:, 2:22].values
        self.trg_kps = self.train_data.iloc[:, 22:].values
        self.cls = ['car(G)', 'car(M)', 'car(S)', 'duck(S)',
                    'motorbike(G)', 'motorbike(M)', 'motorbike(S)',
                    'winebottle(M)', 'winebottle(wC)', 'winebottle(woC)']
        self.cls_ids = list(map(lambda names: self.cls.index(names.split('/')[1]), self.src_imnames))
        self.src_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.src_imnames))
        self.trg_imnames = list(map(lambda x: os.path.join(*x.split('/')[1:]), self.trg_imnames))

    def __getitem__(self, idx):
        r""" Constructs and returns a batch for PF-WILLOW dataset """
        batch = super(PFWillowDataset, self).__getitem__(idx)
        batch['pckthres'] = self.get_pckthres(batch)

        return batch

    def get_pckthres(self, batch):
        r""" Computes PCK threshold """
        if self.thres == 'bbox':
            return max(batch['trg_kps'].max(1)[0] - batch['trg_kps'].min(1)[0]).clone()
        elif self.thres == 'img':
            return torch.tensor(max(batch['trg_img'].size()[1], batch['trg_img'].size()[2]))
        else:
            raise Exception('Invalid pck evaluation level: %s' % self.thres)

    def get_points(self, pts_list, idx, org_imsize):
        r""" Returns key-points of an image """
        point_coords = pts_list[idx, :].reshape(2, 10)
        point_coords = torch.tensor(point_coords.astype(np.float32))
        xy, n_pts = point_coords.size()
        pad_pts = torch.zeros((xy, self.max_pts - n_pts)) - 2
        x_crds = point_coords[0] * (self.img_size / org_imsize[0])
        y_crds = point_coords[1] * (self.img_size / org_imsize[1])
        kps = torch.cat([torch.stack([x_crds, y_crds]), pad_pts], dim=1)

        return kps, n_pts
```
data/spair.py
ADDED
@@ -0,0 +1,105 @@
```python
r""" SPair-71k dataset """

import json
import glob
import os

import torch.nn.functional as F
import torch
from PIL import Image
import numpy as np

from .dataset import CorrespondenceDataset


class SPairDataset(CorrespondenceDataset):

    def __init__(self, benchmark, datapath, thres, split):
        r""" SPair-71k dataset constructor """
        super(SPairDataset, self).__init__(benchmark, datapath, thres, split)

        self.train_data = open(self.spt_path).read().split('\n')
        self.train_data = self.train_data[:len(self.train_data) - 1]
        self.src_imnames = list(map(lambda x: x.split('-')[1] + '.jpg', self.train_data))
        self.trg_imnames = list(map(lambda x: x.split('-')[2].split(':')[0] + '.jpg', self.train_data))
        self.seg_path = os.path.abspath(os.path.join(self.img_path, os.pardir, 'Segmentation'))
        self.cls = os.listdir(self.img_path)
        self.cls.sort()

        anntn_files = []
        for data_name in self.train_data:
            anntn_files.append(glob.glob('%s/%s.json' % (self.ann_path, data_name))[0])
        anntn_files = list(map(lambda x: json.load(open(x)), anntn_files))
        self.src_kps = list(map(lambda x: torch.tensor(x['src_kps']).t().float(), anntn_files))
        self.trg_kps = list(map(lambda x: torch.tensor(x['trg_kps']).t().float(), anntn_files))
        self.src_bbox = list(map(lambda x: torch.tensor(x['src_bndbox']).float(), anntn_files))
        self.trg_bbox = list(map(lambda x: torch.tensor(x['trg_bndbox']).float(), anntn_files))
        self.cls_ids = list(map(lambda x: self.cls.index(x['category']), anntn_files))

        self.vpvar = list(map(lambda x: torch.tensor(x['viewpoint_variation']), anntn_files))
        self.scvar = list(map(lambda x: torch.tensor(x['scale_variation']), anntn_files))
        self.trncn = list(map(lambda x: torch.tensor(x['truncation']), anntn_files))
        self.occln = list(map(lambda x: torch.tensor(x['occlusion']), anntn_files))

    def __getitem__(self, idx):
        r""" Construct and return a batch for SPair-71k dataset """
        sample = super(SPairDataset, self).__getitem__(idx)

        sample['src_mask'] = self.get_mask(sample, sample['src_imname'])
        sample['trg_mask'] = self.get_mask(sample, sample['trg_imname'])

        sample['src_bbox'] = self.get_bbox(self.src_bbox, idx, sample['src_imsize'])
        sample['trg_bbox'] = self.get_bbox(self.trg_bbox, idx, sample['trg_imsize'])
        sample['pckthres'] = self.get_pckthres(sample, sample['trg_imsize'])

        sample['vpvar'] = self.vpvar[idx]
        sample['scvar'] = self.scvar[idx]
        sample['trncn'] = self.trncn[idx]
        sample['occln'] = self.occln[idx]

        return sample

    def get_mask(self, sample, imname):
        mask_path = os.path.join(self.seg_path, sample['category'], imname.split('.')[0] + '.png')

        tensor_mask = torch.tensor(np.array(Image.open(mask_path)))

        class_dict = {'aeroplane': 0, 'bicycle': 1, 'bird': 2, 'boat': 3, 'bottle': 4,
                      'bus': 5, 'car': 6, 'cat': 7, 'chair': 8, 'cow': 9,
                      'diningtable': 10, 'dog': 11, 'horse': 12, 'motorbike': 13, 'person': 14,
                      'pottedplant': 15, 'sheep': 16, 'sofa': 17, 'train': 18, 'tvmonitor': 19}

        class_id = class_dict[sample['category']] + 1
        tensor_mask[tensor_mask != class_id] = 0
        tensor_mask[tensor_mask == class_id] = 255

        tensor_mask = F.interpolate(tensor_mask.unsqueeze(0).unsqueeze(0).float(),
                                    size=(self.img_size, self.img_size),
                                    mode='bilinear', align_corners=True).int().squeeze()

        return tensor_mask

    def get_image(self, img_names, idx):
        r""" Return image tensor """
        path = os.path.join(self.img_path, self.cls[self.cls_ids[idx]], img_names[idx])

        return Image.open(path).convert('RGB')

    def get_pckthres(self, sample, imsize):
        r""" Compute PCK threshold """
        return super(SPairDataset, self).get_pckthres(sample, imsize)

    def get_points(self, pts_list, idx, imsize):
        r""" Return key-points of an image """
        return super(SPairDataset, self).get_points(pts_list, idx, imsize)

    def match_idx(self, kps, n_pts):
        r""" Sample the nearest feature (receptive field) indices """
        return super(SPairDataset, self).match_idx(kps, n_pts)

    def get_bbox(self, bbox_list, idx, imsize):
        r""" Return object bounding-box """
        bbox = bbox_list[idx].clone()
        bbox[0::2] *= (self.img_size / imsize[0])
        bbox[1::2] *= (self.img_size / imsize[1])
        return bbox
```
model/__pycache__/chmlearner.cpython-38.pyc
ADDED
Binary file (1.85 kB).
model/__pycache__/chmnet.cpython-38.pyc
ADDED
Binary file (1.8 kB).
model/base/__pycache__/backbone.cpython-38.pyc
ADDED
Binary file (4.14 kB).
model/base/__pycache__/chm.cpython-38.pyc
ADDED
Binary file (6.85 kB).
model/base/__pycache__/chm_kernel.cpython-38.pyc
ADDED
Binary file (2.03 kB).
model/base/__pycache__/correlation.cpython-38.pyc
ADDED
Binary file (2.09 kB).
model/base/__pycache__/geometry.cpython-38.pyc
ADDED
Binary file (4.69 kB).
model/base/backbone.py
ADDED
@@ -0,0 +1,136 @@
```python
r""" ResNet-101 backbone network """

import torch.utils.model_zoo as model_zoo
import torch.nn as nn
import torch


__all__ = ['Backbone', 'resnet101']


model_urls = {
    'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
    'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
    'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
    'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
    'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
}


def conv3x3(in_planes, out_planes, stride=1):
    r""" 3x3 convolution with padding """
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, groups=2, bias=False)


def conv1x1(in_planes, out_planes, stride=1):
    r""" 1x1 convolution """
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, groups=2, bias=False)


class Bottleneck(nn.Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = nn.BatchNorm2d(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out


class Backbone(nn.Module):
    def __init__(self, block, layers, zero_init_residual=False):
        super(Backbone, self).__init__()

        self.inplanes = 128
        self.conv1 = nn.Conv2d(6, 128, kernel_size=7, stride=2, padding=3, groups=2,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(128)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 128, layers[0])
        self.layer2 = self._make_layer(block, 256, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 512, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 1024, layers[3], stride=2)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, 1000)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.Sequential(*layers)


def resnet101(pretrained=False, **kwargs):
    """Constructs a ResNet-101 model.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
    """
    model = Backbone(Bottleneck, [3, 4, 23, 3], **kwargs)
    if pretrained:
        weights = model_zoo.load_url(model_urls['resnet101'])

        for key in weights:
            if key.split('.')[0] == 'fc':
                weights[key] = weights[key].clone()
                continue
            weights[key] = torch.cat([weights[key].clone(), weights[key].clone()], dim=0)

        model.load_state_dict(weights)
    return model
```
ADDED
@@ -0,0 +1,190 @@
|
```python
r""" 4D and 6D convolutional Hough matching layers """

from torch.nn.modules.conv import _ConvNd
import torch.nn.functional as F
import torch.nn as nn
import torch

from common.logger import Logger
from . import chm_kernel


def fast4d(corr, kernel, bias=None):
    r""" Optimized implementation of 4D convolution """
    bsz, ch, srch, srcw, trgh, trgw = corr.size()
    out_channels, _, kernel_size, kernel_size, kernel_size, kernel_size = kernel.size()
    psz = kernel_size // 2

    out_corr = torch.zeros((bsz, out_channels, srch, srcw, trgh, trgw))
    corr = corr.transpose(1, 2).contiguous().view(bsz * srch, ch, srcw, trgh, trgw)

    for pidx, k3d in enumerate(kernel.permute(2, 0, 1, 3, 4, 5)):
        inter_corr = F.conv3d(corr, k3d, bias=None, stride=1, padding=psz)
        inter_corr = inter_corr.view(bsz, srch, out_channels, srcw, trgh, trgw).transpose(1, 2).contiguous()

        add_sid = max(psz - pidx, 0)
        add_fid = min(srch, srch + psz - pidx)
        slc_sid = max(pidx - psz, 0)
        slc_fid = min(srch, srch - psz + pidx)

        out_corr[:, :, add_sid:add_fid, :, :, :] += inter_corr[:, :, slc_sid:slc_fid, :, :, :]

    if bias is not None:
        out_corr += bias.view(1, out_channels, 1, 1, 1, 1)

    return out_corr


def fast6d(corr, kernel, bias, diagonal_idx):
    r""" Optimized implementation of 6D convolutional Hough matching
    NOTE: this function only supports kernel size of (3, 3, 5, 5, 5, 5).
    """
    bsz, _, s6d, s6d, s4d, s4d, s4d, s4d = corr.size()
    _, _, ks6d, ks6d, ks4d, ks4d, ks4d, ks4d = kernel.size()
    corr = corr.permute(0, 2, 3, 1, 4, 5, 6, 7).contiguous().view(-1, 1, s4d, s4d, s4d, s4d)
    kernel = kernel.view(-1, ks6d ** 2, ks4d, ks4d, ks4d, ks4d).transpose(0, 1)
    corr = fast4d(corr, kernel).view(bsz, s6d * s6d, ks6d * ks6d, s4d, s4d, s4d, s4d)
    corr = corr.view(bsz, s6d, s6d, ks6d, ks6d, s4d, s4d, s4d, s4d).transpose(2, 3).\
        contiguous().view(-1, s6d * ks6d, s4d, s4d, s4d, s4d)

    ndiag = s6d + (ks6d // 2) * 2
    first_sum = []
    for didx in diagonal_idx:
        first_sum.append(corr[:, didx, :, :, :, :].sum(dim=1))
    first_sum = torch.stack(first_sum).transpose(0, 1).view(bsz, s6d * ks6d, ndiag, s4d, s4d, s4d, s4d)

    corr = []
    for didx in diagonal_idx:
        corr.append(first_sum[:, didx, :, :, :, :, :].sum(dim=1))
    sidx = ks6d // 2
    eidx = ndiag - sidx
    corr = torch.stack(corr).transpose(0, 1)[:, sidx:eidx, sidx:eidx, :, :, :, :].unsqueeze(1).contiguous()
    corr += bias.view(1, -1, 1, 1, 1, 1, 1, 1)

    reverse_idx = torch.linspace(s6d * s6d - 1, 0, s6d * s6d).long()
    corr = corr.view(bsz, 1, s6d * s6d, s4d, s4d, s4d, s4d)[:, :, reverse_idx, :, :, :, :].\
        view(bsz, 1, s6d, s6d, s4d, s4d, s4d, s4d)
    return corr


def init_param_idx4d(param_dict):
    param_idx = []
    for key in param_dict:
        curr_offset = int(key.split('_')[-1])
        param_idx.append(torch.tensor(param_dict[key]))
    return param_idx


class CHM4d(_ConvNd):
    r""" 4D convolutional Hough matching layer
    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz4d, ktype, bias=True):
        super(CHM4d, self).__init__(in_channels, out_channels, (ksz4d,) * 4,
                                    (1,) * 4, (0,) * 4, (1,) * 4, False, (0,) * 4,
                                    1, bias, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((in_channels, out_channels, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:
            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            weights = torch.abs(torch.randn(len(self.param_idx) * self.nkernels)) * 1e-3
            for weight, param_idx in zip(weights.sort()[0], self.param_idx):
                weight *= len(param_idx)
            self.weight = nn.Parameter(weights)
        else:  # full kernel initialization
            self.param_idx = None
            self.weight = nn.Parameter(torch.abs(self.weight))
        if bias: self.bias = nn.Parameter(torch.tensor(0.0))
        Logger.info('(%s) # params in CHM 4D: %d' % (ktype, len(self.weight.view(-1))))

    def forward(self, x):
        kernel = self.init_kernel()
        x = fast4d(x, kernel, self.bias)
        return x

    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        ksz = self.kernel_size[-1]
        if self.param_idx is None:
            kernel = self.weight
        else:
            kernel = torch.zeros_like(self.zero_kernel4d)
            for idx, pdx in enumerate(self.param_idx):
                kernel = kernel.view(-1, ksz, ksz, ksz, ksz)
                for jdx, kernel_single in enumerate(kernel):
                    weight = self.weight[idx + jdx * len(self.param_idx)].repeat(len(pdx)) / len(pdx)
                    kernel_single.view(-1)[pdx] += weight
            kernel = kernel.view(self.in_channels, self.out_channels, ksz, ksz, ksz, ksz)
        return kernel


class CHM6d(_ConvNd):
    r""" 6D convolutional Hough matching layer with kernel (3, 3, 5, 5, 5, 5)
    NOTE: this function only supports in_channels=1 and out_channels=1.
    """
    def __init__(self, in_channels, out_channels, ksz6d, ksz4d, ktype):
        kernel_size = (ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d)
        super(CHM6d, self).__init__(in_channels, out_channels, kernel_size, (1,) * 6,
                                    (0,) * 6, (1,) * 6, False, (0,) * 6,
                                    1, bias=True, padding_mode='zeros')

        # Zero kernel initialization
        self.zero_kernel4d = torch.zeros((ksz4d, ksz4d, ksz4d, ksz4d))
        self.zero_kernel6d = torch.zeros((ksz6d, ksz6d, ksz4d, ksz4d, ksz4d, ksz4d))
        self.nkernels = in_channels * out_channels

        # Initialize kernel indices
        # Indices in scale-space where 4D convolutions are performed (3 by 3 scale-space)
        self.diagonal_idx = [torch.tensor(x) for x in [[6], [3, 7], [0, 4, 8], [1, 5], [2]]]
        param_dict4d = chm_kernel.KernelGenerator(ksz4d, ktype).generate()
        param_shared = param_dict4d is not None

        if param_shared:  # psi & iso kernel initialization
            if ktype == 'psi':
                self.param_dict6d = [[4], [0, 8], [2, 6], [1, 3, 5, 7]]
            elif ktype == 'iso':
                self.param_dict6d = [[0, 4, 8], [2, 6], [1, 3, 5, 7]]
            self.param_dict6d = [torch.tensor(i) for i in self.param_dict6d]

            # Initialize the shared parameters (multiplied by the number of times being shared)
            self.param_idx = init_param_idx4d(param_dict4d)
            self.param = []
            for param_dict6d in self.param_dict6d:
                weights = torch.abs(torch.randn(len(self.param_idx))) * 1e-3
                for weight, param_idx in zip(weights, self.param_idx):
                    weight *= (len(param_idx) * len(param_dict6d))
                self.param.append(nn.Parameter(weights))
            self.param = nn.ParameterList(self.param)
        else:  # full kernel initialization
            self.param_idx = None
            self.param = nn.Parameter(torch.abs(self.weight) * 1e-3)
        Logger.info('(%s) # params in CHM 6D: %d' % (ktype, sum([len(x.view(-1)) for x in self.param])))
        self.weight = None

    def forward(self, corr):
        kernel = self.init_kernel()
        corr = fast6d(corr, kernel, self.bias, self.diagonal_idx)
        return corr

    def init_kernel(self):
        # Initialize CHM kernel (divided by the number of times being shared)
        if self.param_idx is None:
            return self.param

        kernel6d = torch.zeros_like(self.zero_kernel6d)
        for idx, (param, param_dict6d) in enumerate(zip(self.param, self.param_dict6d)):
            ksz4d = self.kernel_size[-1]
            kernel4d = torch.zeros_like(self.zero_kernel4d)
            for jdx, pdx in enumerate(self.param_idx):
                kernel4d.view(-1)[pdx] += ((param[jdx] / len(pdx)) / len(param_dict6d))
            kernel6d.view(-1, ksz4d, ksz4d, ksz4d, ksz4d)[param_dict6d] += kernel4d.view(ksz4d, ksz4d, ksz4d, ksz4d)
        kernel6d = kernel6d.unsqueeze(0).unsqueeze(0)

        return kernel6d
```
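A small shape check for `fast4d` (assuming the repository and its dependencies are importable; the sizes are arbitrary): the 4D correlation keeps its spatial extent because the 3D convolutions are padded and the remaining source-height dimension is handled by the shifted accumulation loop.

```python
import torch

from model.base.chm import fast4d

bsz, ch, side, ksz = 2, 1, 8, 5
corr = torch.randn(bsz, ch, side, side, side, side)   # (B, C, src_h, src_w, trg_h, trg_w)
kernel = torch.randn(1, ch, ksz, ksz, ksz, ksz)       # one output channel, 5^4 kernel
out = fast4d(corr, kernel)
print(out.shape)                                      # torch.Size([2, 1, 8, 8, 8, 8])
```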
model/base/chm_kernel.py
ADDED
@@ -0,0 +1,66 @@
r""" CHM 4D kernel (psi, iso, and full) generator """

import torch

from .geometry import Geometry


class KernelGenerator:
    def __init__(self, ksz, ktype):
        self.ksz = ksz
        self.idx4d = Geometry.init_idx4d(ksz)
        self.kernel = torch.zeros((ksz, ksz, ksz, ksz))
        self.center = (ksz // 2, ksz // 2)
        self.ktype = ktype

    def quadrant(self, crd):
        # Horizontal/vertical quadrant of a coordinate relative to the kernel center
        if crd[0] < self.center[0]:
            horz_quad = -1
        elif crd[0] > self.center[0]:
            horz_quad = 1
        else:
            horz_quad = 0

        if crd[1] < self.center[1]:
            vert_quad = -1
        elif crd[1] > self.center[1]:
            vert_quad = 1
        else:
            vert_quad = 0

        return horz_quad, vert_quad

    def generate(self):
        return None if self.ktype == 'full' else self.generate_chm_kernel()

    def generate_chm_kernel(self):
        param_dict = {}
        for idx in self.idx4d:
            src_i, src_j, trg_i, trg_j = idx
            d_tail = Geometry.get_distance((src_i, src_j), self.center)
            d_head = Geometry.get_distance((trg_i, trg_j), self.center)
            d_off = Geometry.get_distance((src_i, src_j), (trg_i, trg_j))
            horz_quad, vert_quad = self.quadrant((src_j, src_i))

            src_crd = (src_i, src_j)
            trg_crd = (trg_i, trg_j)

            key = self.build_key(horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off)
            coord1d = Geometry.get_coord1d((src_i, src_j, trg_i, trg_j), self.ksz)

            if param_dict.get(key) is None: param_dict[key] = []
            param_dict[key].append(coord1d)

        return param_dict

    def build_key(self, horz_quad, vert_quad, d_head, d_tail, src_crd, trg_crd, d_off):
        if self.ktype == 'iso':
            return '%d' % d_off
        elif self.ktype == 'psi':
            d_max = max(d_head, d_tail)
            d_min = min(d_head, d_tail)
            return '%d_%d_%d' % (d_max, d_min, d_off)
        else:
            raise Exception('not implemented.')
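A minimal usage sketch of the generator above: for the 'psi' kernel of size 5 used by CHMLearner below, generate() returns a dict mapping each shared-parameter key to the 1-D kernel coordinates that share that parameter. The printed quantities are only an illustration of how to inspect it.

# Hedged sketch: inspect how many distinct parameters a shared 4D kernel uses.
from model.base.chm_kernel import KernelGenerator

param_dict = KernelGenerator(ksz=5, ktype='psi').generate()
print(len(param_dict))                             # number of shared parameters
print(sum(len(v) for v in param_dict.values()))    # total kernel entries, 5 ** 4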
model/base/correlation.py
ADDED
@@ -0,0 +1,68 @@
r""" Provides functions that create/manipulate correlation matrices """

import math

from torch.nn.functional import interpolate as resize
import torch

from .geometry import Geometry


class Correlation:

    @classmethod
    def mutual_nn_filter(cls, correlation_matrix, eps=1e-30):
        r""" Mutual nearest neighbor filtering (Rocco et al. NeurIPS'18) """
        corr_src_max = torch.max(correlation_matrix, dim=2, keepdim=True)[0]
        corr_trg_max = torch.max(correlation_matrix, dim=1, keepdim=True)[0]
        corr_src_max[corr_src_max == 0] += eps
        corr_trg_max[corr_trg_max == 0] += eps

        corr_src = correlation_matrix / corr_src_max
        corr_trg = correlation_matrix / corr_trg_max

        return correlation_matrix * (corr_src * corr_trg)

    @classmethod
    def build_correlation6d(cls, src_feat, trg_feat, scales, conv2ds):
        r""" Build 6-dimensional correlation tensor """

        bsz, _, side, side = src_feat.size()

        # Construct feature pairs with multiple scales
        _src_feats = []
        _trg_feats = []
        for scale, conv in zip(scales, conv2ds):
            s = (round(side * math.sqrt(scale)),) * 2
            _src_feat = conv(resize(src_feat, s, mode='bilinear', align_corners=True))
            _trg_feat = conv(resize(trg_feat, s, mode='bilinear', align_corners=True))
            _src_feats.append(_src_feat)
            _trg_feats.append(_trg_feat)

        # Build multiple 4-dimensional correlation tensors
        corr6d = []
        for src_feat in _src_feats:
            ch = src_feat.size(1)

            src_side = src_feat.size(-1)
            src_feat = src_feat.view(bsz, ch, -1).transpose(1, 2)
            src_norm = src_feat.norm(p=2, dim=2, keepdim=True)

            for trg_feat in _trg_feats:
                trg_side = trg_feat.size(-1)
                trg_feat = trg_feat.view(bsz, ch, -1)
                trg_norm = trg_feat.norm(p=2, dim=1, keepdim=True)

                correlation = torch.bmm(src_feat, trg_feat) / torch.bmm(src_norm, trg_norm)
                correlation = correlation.view(bsz, src_side, src_side, trg_side, trg_side).contiguous()
                corr6d.append(correlation)

        # Resize the spatial sizes of the 4D tensors to the same size
        for idx, correlation in enumerate(corr6d):
            corr6d[idx] = Geometry.interpolate4d(correlation, [side, side])

        # Build 6-dimensional correlation tensor
        corr6d = torch.stack(corr6d).view(len(scales), len(scales),
                                          bsz, side, side, side, side).permute(2, 0, 1, 3, 4, 5, 6)
        return corr6d.clamp(min=0)
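A minimal sketch of the mutual nearest-neighbour filter in isolation; the (batch, HW, HW) correlation shape matches how CHMLearner calls it, with the sizes below chosen only for illustration.

# Hedged sketch: suppress non-mutual matches in a random (batch, HW, HW) correlation.
import torch
from model.base.correlation import Correlation

corr = torch.rand(2, 900, 900)               # e.g. 30x30 source and target grids
filtered = Correlation.mutual_nn_filter(corr)
print(filtered.shape)                         # torch.Size([2, 900, 900])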
model/base/geometry.py
ADDED
@@ -0,0 +1,133 @@
r""" Provides functions that manipulate boxes and points """

import math

import torch.nn.functional as F
import torch


class Geometry(object):

    @classmethod
    def initialize(cls, img_size):
        cls.img_size = img_size

        cls.spatial_side = int(img_size / 8)
        norm_grid1d = torch.linspace(-1, 1, cls.spatial_side)

        cls.norm_grid_x = norm_grid1d.view(1, -1).repeat(cls.spatial_side, 1).view(1, 1, -1)
        cls.norm_grid_y = norm_grid1d.view(-1, 1).repeat(1, cls.spatial_side).view(1, 1, -1)
        cls.grid = torch.stack(list(reversed(torch.meshgrid(norm_grid1d, norm_grid1d)))).permute(1, 2, 0)

        cls.feat_idx = torch.arange(0, cls.spatial_side).float()

    @classmethod
    def normalize_kps(cls, kps):
        kps = kps.clone().detach()
        kps[kps != -2] -= (cls.img_size // 2)
        kps[kps != -2] /= (cls.img_size // 2)
        return kps

    @classmethod
    def unnormalize_kps(cls, kps):
        kps = kps.clone().detach()
        kps[kps != -2] *= (cls.img_size // 2)
        kps[kps != -2] += (cls.img_size // 2)
        return kps

    @classmethod
    def attentive_indexing(cls, kps, thres=0.1):
        r"""kps: normalized keypoints x, y (N, 2)
            returns attentive index map (N, spatial_side, spatial_side)
        """
        nkps = kps.size(0)
        kps = kps.view(nkps, 1, 1, 2)

        eps = 1e-5
        attmap = (cls.grid.unsqueeze(0).repeat(nkps, 1, 1, 1) - kps).pow(2).sum(dim=3)
        attmap = (attmap + eps).pow(0.5)
        attmap = (thres - attmap).clamp(min=0).view(nkps, -1)
        attmap = attmap / attmap.sum(dim=1, keepdim=True)
        attmap = attmap.view(nkps, cls.spatial_side, cls.spatial_side)

        return attmap

    @classmethod
    def apply_gaussian_kernel(cls, corr, sigma=17):
        bsz, side, side = corr.size()

        center = corr.max(dim=2)[1]
        center_y = center // cls.spatial_side
        center_x = center % cls.spatial_side

        y = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_y.size(1), 1) - center_y.unsqueeze(2)
        x = cls.feat_idx.view(1, 1, cls.spatial_side).repeat(bsz, center_x.size(1), 1) - center_x.unsqueeze(2)

        y = y.unsqueeze(3).repeat(1, 1, 1, cls.spatial_side)
        x = x.unsqueeze(2).repeat(1, 1, cls.spatial_side, 1)

        gauss_kernel = torch.exp(-(x.pow(2) + y.pow(2)) / (2 * sigma ** 2))
        filtered_corr = gauss_kernel * corr.view(bsz, -1, cls.spatial_side, cls.spatial_side)
        filtered_corr = filtered_corr.view(bsz, side, side)

        return filtered_corr

    @classmethod
    def transfer_kps(cls, confidence_ts, src_kps, n_pts, normalized):
        r""" Transfer keypoints by weighted average """

        if not normalized:
            src_kps = Geometry.normalize_kps(src_kps)
        confidence_ts = cls.apply_gaussian_kernel(confidence_ts)

        pdf = F.softmax(confidence_ts, dim=2)
        prd_x = (pdf * cls.norm_grid_x).sum(dim=2)
        prd_y = (pdf * cls.norm_grid_y).sum(dim=2)

        prd_kps = []
        for idx, (x, y, src_kp, np) in enumerate(zip(prd_x, prd_y, src_kps, n_pts)):
            max_pts = src_kp.size()[1]
            prd_xy = torch.stack([x, y]).t()

            src_kp = src_kp[:, :np].t()
            attmap = cls.attentive_indexing(src_kp).view(np, -1)
            prd_kp = (prd_xy.unsqueeze(0) * attmap.unsqueeze(-1)).sum(dim=1).t()
            pads = (torch.zeros((2, max_pts - np)) - 2)
            prd_kp = torch.cat([prd_kp, pads], dim=1)
            prd_kps.append(prd_kp)

        return torch.stack(prd_kps)

    @staticmethod
    def get_coord1d(coord4d, ksz):
        i, j, k, l = coord4d
        coord1d = i * (ksz ** 3) + j * (ksz ** 2) + k * (ksz) + l
        return coord1d

    @staticmethod
    def get_distance(coord1, coord2):
        delta_y = int(math.pow(coord1[0] - coord2[0], 2))
        delta_x = int(math.pow(coord1[1] - coord2[1], 2))
        dist = delta_y + delta_x
        return dist

    @staticmethod
    def interpolate4d(tensor4d, size):
        bsz, h1, w1, h2, w2 = tensor4d.size()
        tensor4d = tensor4d.view(bsz, h1, w1, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, h2, w2, -1).permute(0, 3, 1, 2)
        tensor4d = F.interpolate(tensor4d, size, mode='bilinear', align_corners=True)
        tensor4d = tensor4d.view(bsz, size[0], size[0], size[0], size[0])

        return tensor4d

    @staticmethod
    def init_idx4d(ksz):
        i0 = torch.arange(0, ksz).repeat(ksz ** 3)
        i1 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz).view(-1).repeat(ksz ** 2)
        i2 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 2).view(-1).repeat(ksz)
        i3 = torch.arange(0, ksz).unsqueeze(1).repeat(1, ksz ** 3).view(-1)
        idx4d = torch.stack([i3, i2, i1, i0]).t().numpy()

        return idx4d
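A minimal sketch of the class-level state Geometry relies on: initialize() must be called once with the working image size before normalize_kps or transfer_kps are used. The image size of 240 and the keypoint values below are illustrative assumptions, not values fixed by this file.

# Hedged sketch: keypoints round-trip through the normalized [-1, 1] coordinate system.
import torch
from model.base.geometry import Geometry

Geometry.initialize(img_size=240)                          # sets spatial_side = 240 / 8 = 30
kps = torch.tensor([[60., 120., -2.], [90., 30., -2.]])    # (2, n) x/y rows, -2 marks padding
norm = Geometry.normalize_kps(kps)
back = Geometry.unnormalize_kps(norm)
print(torch.allclose(back, kps))                           # True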
model/chmlearner.py
ADDED
@@ -0,0 +1,52 @@
r""" Convolutional Hough matching layers """

import torch.nn as nn
import torch

from .base.correlation import Correlation
from .base.geometry import Geometry
from .base.chm import CHM4d, CHM6d


class CHMLearner(nn.Module):

    def __init__(self, ktype, feat_dim):
        super(CHMLearner, self).__init__()

        # Scale-wise feature transformation
        self.scales = [0.5, 1, 2]
        self.conv2ds = nn.ModuleList([nn.Conv2d(feat_dim, feat_dim // 4, kernel_size=3, padding=1, bias=False) for _ in self.scales])

        # CHM layers
        ksz_translation = 5
        ksz_scale = 3
        self.chm6d = CHM6d(1, 1, ksz_scale, ksz_translation, ktype)
        self.chm4d = CHM4d(1, 1, ksz_translation, ktype, bias=True)

        # Activations
        self.relu = nn.ReLU(inplace=True)
        self.sigmoid = nn.Sigmoid()
        self.softplus = nn.Softplus()

    def forward(self, src_feat, trg_feat):

        corr = Correlation.build_correlation6d(src_feat, trg_feat, self.scales, self.conv2ds).unsqueeze(1)
        bsz, ch, s, s, h, w, h, w = corr.size()

        # CHM layer (6D)
        corr = self.chm6d(corr)
        corr = self.sigmoid(corr)

        # Scale-space maxpool
        corr = corr.view(bsz, -1, h, w, h, w).max(dim=1)[0]
        corr = Geometry.interpolate4d(corr, [h * 2, w * 2]).unsqueeze(1)

        # CHM layer (4D)
        corr = self.chm4d(corr).squeeze(1)

        # To ensure non-negative vote scores & soft cyclic constraints
        corr = self.softplus(corr)
        corr = Correlation.mutual_nn_filter(corr.view(bsz, corr.size(-1) ** 2, corr.size(-1) ** 2).contiguous())

        return corr
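A minimal shape-flow sketch of the learner above, assuming ResNet-101 layer-3 features (feat_dim=1024) on a 15x15 spatial grid; both numbers are assumptions about the demo's input resolution rather than values fixed in this file.

# Hedged sketch: correlation refinement from feature maps to a filtered (HW x HW) matrix.
import torch
from model.chmlearner import CHMLearner

learner = CHMLearner(ktype='psi', feat_dim=1024)
src_feat = torch.randn(1, 1024, 15, 15)       # illustrative layer-3 feature maps
trg_feat = torch.randn(1, 1024, 15, 15)
corr = learner(src_feat, trg_feat)
print(corr.shape)                              # torch.Size([1, 900, 900]) after the 2x upsampling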
model/chmnet.py
ADDED
@@ -0,0 +1,42 @@
r""" Convolutional Hough Matching Networks """

import torch.nn as nn
import torch

from . import chmlearner as chmlearner
from .base import backbone


class CHMNet(nn.Module):
    def __init__(self, ktype):
        super(CHMNet, self).__init__()

        self.backbone = backbone.resnet101(pretrained=True)
        self.learner = chmlearner.CHMLearner(ktype, feat_dim=1024)

    def forward(self, src_img, trg_img):
        src_feat, trg_feat = self.extract_features(src_img, trg_img)
        correlation = self.learner(src_feat, trg_feat)
        return correlation

    def extract_features(self, src_img, trg_img):
        feat = self.backbone.conv1.forward(torch.cat([src_img, trg_img], dim=1))
        feat = self.backbone.bn1.forward(feat)
        feat = self.backbone.relu.forward(feat)
        feat = self.backbone.maxpool.forward(feat)

        for idx in range(1, 5):
            feat = self.backbone.__getattr__('layer%d' % idx)(feat)

            if idx == 3:
                src_feat = feat.narrow(1, 0, feat.size(1) // 2).clone()
                trg_feat = feat.narrow(1, feat.size(1) // 2, feat.size(1) // 2).clone()
                return src_feat, trg_feat

    def training_objective(cls, prd_kps, trg_kps, npts):
        l2dist = (prd_kps - trg_kps).pow(2).sum(dim=1)
        loss = []
        for dist, npt in zip(l2dist, npts):
            loss.append(dist[:npt].mean())
        return torch.stack(loss).mean()
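A minimal sketch of how the demo in app.py presumably consumes the network's output: the (batch, HW, HW) correlation is fed to Geometry.transfer_kps and the result is mapped back to pixel coordinates. A random correlation stands in for CHMNet's output so the sketch runs without the pretrained backbone; the 240x240 image size, keypoint values, and 40-slot padding are illustrative assumptions.

# Hedged sketch: turning a (B, HW, HW) correlation into transferred keypoints.
import torch
from model.base.geometry import Geometry

Geometry.initialize(img_size=240)                        # spatial_side = 30, so HW = 900
corr = torch.rand(1, 900, 900)                           # stand-in for CHMNet(src_img, trg_img)
src_kps = torch.full((1, 2, 40), -2.0)                   # (batch, xy, max_pts); -2 pads unused slots
src_kps[0, :, :3] = torch.tensor([[30., 120., 200.],     # x coordinates of 3 keypoints
                                  [40., 100., 180.]])    # y coordinates
n_pts = [3]                                              # number of valid keypoints per pair
prd_kps = Geometry.transfer_kps(corr, src_kps, n_pts, normalized=False)
prd_kps = Geometry.unnormalize_kps(prd_kps)              # back to pixel coordinates
print(prd_kps.shape)                                     # torch.Size([1, 2, 40])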
requirements.txt
ADDED
@@ -0,0 +1,10 @@
gradio==2.4.5
matplotlib==3.4.3
numpy==1.21.2
pandas==1.3.4
Pillow==8.4.0
requests==2.26.0
scipy==1.7.1
tensorboardX==2.4.1
torch==1.10.0
torchvision==0.11.1