Spaces: Running

Vincentqyw committed · Commit 2947428 · Parent(s): a9f1fc6

fix:roma

Files changed:
- common/utils.py (+4, -4)
- third_party/Roma/roma/models/encoders.py (+33, -81)
- third_party/Roma/roma/models/matcher.py (+145, -267)
common/utils.py CHANGED

```diff
@@ -49,7 +49,7 @@ def gen_examples():
         "topicfm",
         "superpoint+superglue",
         "disk+dualsoftmax",
-        "
+        "roma",
     ]
 
 def gen_images_pairs(path: str, count: int = 5):
@@ -452,12 +452,11 @@ ransac_zoo = {
 
 # Matchers collections
 matcher_zoo = {
-    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
-    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # 'dedode-sparse': {
    #     'config': match_dense.confs['dedode_sparse'],
    #     'dense': True  # dense mode, we need 2 images
    # },
+    "roma": {"config": match_dense.confs["roma"], "dense": True},
     "loftr": {"config": match_dense.confs["loftr"], "dense": True},
     "topicfm": {"config": match_dense.confs["topicfm"], "dense": True},
     "aspanformer": {"config": match_dense.confs["aspanformer"], "dense": True},
@@ -556,6 +555,7 @@ matcher_zoo = {
         "config_feature": extract_features.confs["sift"],
         "dense": False,
     },
-
+    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
+    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
 }
```

(The lone `"` in the first hunk is how the deleted line appears in the source rendering; its full content was truncated and is left as-is.)
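For context, here is a minimal sketch of how a `matcher_zoo` entry such as the new `"roma"` registration might be consumed downstream. The helper below is hypothetical, introduced only for illustration; the real lookup lives elsewhere in common/utils.py:

```python
# Hypothetical sketch, not part of this commit: reading a matcher_zoo entry.
from common.utils import matcher_zoo  # assumed import path


def lookup_matcher(name: str):
    """Return the config for a registered matcher and whether it is dense."""
    entry = matcher_zoo[name]  # e.g. name = "roma"
    return entry["config"], entry["dense"]


config, is_dense = lookup_matcher("roma")
# dense=True means the matcher consumes an image pair directly, rather than
# extracting per-image keypoints first and matching descriptors afterwards.
```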
|
third_party/Roma/roma/models/encoders.py CHANGED

```diff
@@ -6,59 +6,37 @@ import torch.nn.functional as F
 import torchvision.models as tvm
 import gc
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 
 class ResNet50(nn.Module):
-    def __init__(
-        self,
-        pretrained=False,
-        high_res=False,
-        weights=None,
-        dilation=None,
-        freeze_bn=True,
-        anti_aliased=False,
-        early_exit=False,
-        amp=False,
-    ) -> None:
+    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
         super().__init__()
         if dilation is None:
-            dilation = [False, False, False]
+            dilation = [False,False,False]
         if anti_aliased:
             pass
         else:
             if weights is not None:
-                self.net = tvm.resnet50(
-                    weights=weights, replace_stride_with_dilation=dilation
-                )
+                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
             else:
-                self.net = tvm.resnet50(
-                    pretrained=pretrained, replace_stride_with_dilation=dilation
-                )
-
+                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
         self.high_res = high_res
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             net = self.net
-            feats = {1: x}
+            feats = {1:x}
             x = net.conv1(x)
             x = net.bn1(x)
             x = net.relu(x)
             feats[2] = x
             x = net.maxpool(x)
             x = net.layer1(x)
             feats[4] = x
             x = net.layer2(x)
             feats[8] = x
             if self.early_exit:
@@ -77,48 +55,35 @@ class ResNet50(nn.Module):
                 m.eval()
         pass
 
-
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp=False) -> None:
+    def __init__(self, pretrained=False, amp = False) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             feats = {}
             scale = 1
             for layer in self.layers:
                 if isinstance(layer, nn.MaxPool2d):
                     feats[scale] = x
-                    scale = scale * 2
+                    scale = scale*2
                 x = layer(x)
             return feats
 
-
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
         super().__init__()
         if dinov2_weights is None:
-            dinov2_weights = torch.hub.load_state_dict_from_url(
-                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
-                map_location="cpu",
-            )
+            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
         from .transformer import vit_large
-
-        vit_kwargs = dict(
-            img_size=518,
-            patch_size=14,
-            init_values=1.0,
-            ffn_layer="mlp",
-            block_chunks=0,
+        vit_kwargs = dict(img_size= 518,
+            patch_size= 14,
+            init_values = 1.0,
+            ffn_layer = "mlp",
+            block_chunks = 0,
         )
 
         dinov2_vitl14 = vit_large(**vit_kwargs).eval()
@@ -129,38 +94,25 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
-        self.dinov2_vitl14 = [dinov2_vitl14]
-
+        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
     def train(self, mode: bool = True):
         return self.cnn.train(mode)
-
-    def forward(self, x, upsample=False):
-        B, C, H, W = x.shape
+
+    def forward(self, x, upsample = False):
+        B,C,H,W = x.shape
         feature_pyramid = self.cnn(x)
-
+
         if not upsample:
             with torch.no_grad():
                 if self.dinov2_vitl14[0].device != x.device:
-                    self.dinov2_vitl14[0] = (
-                        self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
-                    )
-                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
-                    x.to(self.amp_dtype)
-                )
-                features_16 = (
-                    dinov2_features_16["x_norm_patchtokens"]
-                    .permute(0, 2, 1)
-                    .reshape(B, 1024, H // 14, W // 14)
-                )
+                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
                 del dinov2_features_16
                 feature_pyramid[16] = features_16
-            return feature_pyramid
+        return feature_pyramid
```

(Several deleted lines above were truncated mid-line in the source rendering, e.g. `if torch.cuda.` and `with torch.autocast(`; they are reconstructed here from the surrounding context and the matching one-liners on the added side.)
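The net effect in this file: the module-level `device` global and the CPU fallback to `torch.float32` are gone, `torch.autocast` is pinned to `"cuda"`, the AMP dtype is chosen with a single `is_bf16_supported()` check, and `return feature_pyramid` is dedented so the encoder also returns a pyramid when `upsample=True`. A minimal sketch of exercising the reverted encoder, assuming a CUDA machine (which the hardcoded `"cuda"` now requires) and input sides divisible by 14 for the DINOv2 patch embedding; constructing `CNNandDinov2` downloads the ViT-L/14 checkpoint unless `dinov2_weights` is supplied:

```python
# Minimal sketch, assuming a CUDA device and the third_party/Roma package layout.
import torch
from roma.models.encoders import CNNandDinov2

encoder = CNNandDinov2(cnn_kwargs=dict(amp=True), amp=True).cuda()
x = torch.randn(2, 3, 448, 448, device="cuda")  # 448 = 32 * 14

with torch.no_grad():
    pyramid = encoder(x)

# The CNN fills the fine strides of the pyramid; the DINOv2 patch tokens are
# reshaped to (B, 1024, H // 14, W // 14) and stored under key 16.
print({scale: tuple(f.shape) for scale, f in pyramid.items()})
```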
third_party/Roma/roma/models/matcher.py CHANGED

```diff
@@ -14,9 +14,6 @@ from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
 class ConvRefiner(nn.Module):
     def __init__(
         self,
@@ -26,29 +23,25 @@ class ConvRefiner(nn.Module):
         dw=False,
         kernel_size=5,
         hidden_blocks=3,
-        displacement_emb=None,
-        displacement_emb_dim=None,
-        local_corr_radius=None,
-        corr_in_other=None,
-        no_im_B_fm=False,
-        amp=False,
-        concat_logits=False,
-        use_bias_block_1=True,
-        use_cosine_corr=False,
-        disable_local_corr_grad=False,
-        is_classifier=False,
-        sample_mode="bilinear",
-        norm_type=nn.BatchNorm2d,
-        bn_momentum=0.1,
+        displacement_emb = None,
+        displacement_emb_dim = None,
+        local_corr_radius = None,
+        corr_in_other = None,
+        no_im_B_fm = False,
+        amp = False,
+        concat_logits = False,
+        use_bias_block_1 = True,
+        use_cosine_corr = False,
+        disable_local_corr_grad = False,
+        is_classifier = False,
+        sample_mode = "bilinear",
+        norm_type = nn.BatchNorm2d,
+        bn_momentum = 0.1,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
         self.block1 = self.create_block(
-            in_dim,
-            hidden_dim,
-            dw=dw,
-            kernel_size=kernel_size,
-            bias=use_bias_block_1,
+            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
         )
         self.hidden_blocks = nn.Sequential(
             *[
@@ -66,7 +59,7 @@ class ConvRefiner(nn.Module):
         self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         if displacement_emb:
             self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
+            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
         else:
             self.has_displacement_emb = False
         self.local_corr_radius = local_corr_radius
@@ -78,22 +71,16 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def create_block(
         self,
         in_dim,
         out_dim,
         dw=False,
         kernel_size=5,
-        bias=True,
-        norm_type=nn.BatchNorm2d,
+        bias = True,
+        norm_type = nn.BatchNorm2d,
     ):
         num_groups = 1 if not dw else in_dim
         if dw:
@@ -109,56 +96,38 @@ class ConvRefiner(nn.Module):
                 groups=num_groups,
                 bias=bias,
             )
-        norm = (
-            norm_type(out_dim, momentum=self.bn_momentum)
-            if norm_type is nn.BatchNorm2d
-            else norm_type(num_channels=out_dim)
-        )
+        norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
         relu = nn.ReLU(inplace=True)
         conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         return nn.Sequential(conv1, norm, relu, conv2)
-
-    def forward(self, x, y, flow, scale_factor=1, logits=None):
-        b, c, hs, ws = x.shape
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+
+    def forward(self, x, y, flow, scale_factor = 1, logits = None):
+        b,c,hs,ws = x.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             with torch.no_grad():
-                x_hat = F.grid_sample(
-                    y,
-                    flow.permute(0, 2, 3, 1),
-                    align_corners=False,
-                    mode=self.sample_mode,
-                )
+                x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
-                    (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-                    )
+                    (
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                 im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-                in_displacement = flow - im_A_coords
-                emb_in_displacement = self.disp_emb(
-                    40 / 32 * scale_factor * in_displacement
-                )
+                in_displacement = flow-im_A_coords
+                emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
                 if self.local_corr_radius:
                     if self.corr_in_other:
                         # Corr in other means take a kxk grid around the predicted coordinate in other image
-                        local_corr = local_correlation(
-                            x,
-                            y,
-                            local_radius=self.local_corr_radius,
-                            flow=flow,
-                            sample_mode=self.sample_mode,
-                        )
+                        local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
+                                                       sample_mode = self.sample_mode)
                     else:
-                        raise NotImplementedError(
-                            "Local corr in own frame should not be used."
-                        )
+                        raise NotImplementedError("Local corr in own frame should not be used.")
                     if self.no_im_B_fm:
                         x_hat = torch.zeros_like(x)
                     d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
-            else:
+                else:
                     d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
             else:
                 if self.no_im_B_fm:
@@ -172,7 +141,6 @@ class ConvRefiner(nn.Module):
         displacement, certainty = d[:, :-1], d[:, -1:]
         return displacement, certainty
 
-
 class CosKernel(nn.Module):  # similar to softmax kernel
     def __init__(self, T, learn_temperature=False):
         super().__init__()
@@ -193,7 +161,6 @@ class CosKernel(nn.Module):  # similar to softmax kernel
         K = ((c - 1.0) / T).exp()
         return K
 
-
 class GP(nn.Module):
     def __init__(
         self,
@@ -207,7 +174,7 @@ class GP(nn.Module):
         only_nearest_neighbour=False,
         sigma_noise=0.1,
         no_cov=False,
-        predict_features=False,
+        predict_features = False,
     ):
         super().__init__()
         self.K = kernel(T=T, learn_temperature=learn_temperature)
@@ -295,9 +262,7 @@ class GP(nn.Module):
         mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
         if not self.no_cov:
             cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-            cov_x = rearrange(
-                cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
-            )
+            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
             local_cov_x = self.get_local_cov(cov_x)
             local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
             gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
@@ -305,22 +270,11 @@ class GP(nn.Module):
             gp_feats = mu_x
         return gp_feats
 
-
 class Decoder(nn.Module):
     def __init__(
-        self,
-        embedding_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=False,
-        scales="all",
-        pos_embeddings=None,
-        num_refinement_steps_per_scale=1,
-        warp_noise_std=0.0,
-        displacement_dropout_p=0.0,
-        gm_warp_dropout_p=0.0,
-        flow_upsample_mode="bilinear",
+        self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+        num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+        flow_upsample_mode = "bilinear"
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -342,14 +296,8 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -362,8 +310,8 @@ class Decoder(nn.Module):
         ].expand(b, h, w, 2)
         coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
         return coarse_coords
-
-    def get_positional_embedding(self, b, h, w, device):
+
+    def get_positional_embedding(self, b, h ,w, device):
         coarse_coords = torch.meshgrid(
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
@@ -378,29 +326,16 @@ class Decoder(nn.Module):
         coarse_embedded_coords = self.pos_embedding(coarse_coords)
         return coarse_embedded_coords
 
-    def forward(
-        self,
-        f1,
-        f2,
-        gt_warp=None,
-        gt_prob=None,
-        upsample=False,
-        flow=None,
-        certainty=None,
-        scale_factor=1,
-    ):
+    def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
         coarse_scales = self.embedding_decoder.scales()
-        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
         sizes = {scale: f1[scale].shape[-2:] for scale in f1}
         h, w = sizes[1]
         b = f1[1].shape[0]
         device = f1[1].device
         coarsest_scale = int(all_scales[0])
         old_stuff = torch.zeros(
-            b,
-            self.embedding_decoder.hidden_dim,
-            *sizes[coarsest_scale],
-            device=f1[coarsest_scale].device,
+            b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
         )
         corresps = {}
         if not upsample:
@@ -408,24 +343,24 @@ class Decoder(nn.Module):
             certainty = 0.0
         else:
             flow = F.interpolate(
-                flow,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                flow,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             certainty = F.interpolate(
-                certainty,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                certainty,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
         displacement = 0.0
         for new_scale in all_scales:
             ins = int(new_scale)
             corresps[ins] = {}
             f1_s, f2_s = f1[ins], f2[ins]
             if new_scale in self.proj:
-                with torch.autocast(device, self.amp_dtype):
+                with torch.autocast("cuda", self.amp_dtype):
                    f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
             if ins in coarse_scales:
@@ -436,59 +371,32 @@ class Decoder(nn.Module):
                 gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                     gp_posterior, f1_s, old_stuff, new_scale
                 )
-
+
                 if self.embedding_decoder.is_classifier:
                     flow = cls_to_flow_refine(
                         gm_warp_or_cls,
-                    ).permute(0, 3, 1, 2)
-                    corresps[ins].update(
-                        {
-                            "gm_cls": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    ).permute(0,3,1,2)
+                    corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                 else:
-                    corresps[ins].update(
-                        {
-                            "gm_flow": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                     flow = gm_warp_or_cls.detach()
-
+
             if new_scale in self.conv_refiner:
-                corresps[ins].update(
-                    {"flow_pre_delta": flow}
-                ) if self.training else None
+                corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
                 delta_flow, delta_certainty = self.conv_refiner[new_scale](
-                    f1_s,
-                    f2_s,
-                    flow,
-                    scale_factor=scale_factor,
-                    logits=certainty,
-                )
-                corresps[ins].update(
-                    {
-                        "delta_flow": delta_flow,
-                    }
-                ) if self.training else None
-                displacement = ins * torch.stack(
-                    (
-                        delta_flow[:, 0].float() / (self.refine_init * w),
-                        delta_flow[:, 1].float() / (self.refine_init * h),
-                    ),
-                    dim=1,
-                )
+                    f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+                )
+                corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+                displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+                                                delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
                 flow = flow + displacement
                 certainty = (
                     certainty + delta_certainty
                 )  # predict both certainty and displacement
-                corresps[ins].update(
-                    {
-                        "certainty": certainty,
-                        "flow": flow,
-                    }
-                )
+                corresps[ins].update({
+                    "certainty": certainty,
+                    "flow": flow,
+                })
             if new_scale != "1":
                 flow = F.interpolate(
                     flow,
@@ -503,7 +411,7 @@ class Decoder(nn.Module):
             if self.detach:
                 flow = flow.detach()
                 certainty = certainty.detach()
-        # torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
         return corresps
 
 
@@ -514,11 +422,11 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=448,
         w=448,
-        sample_mode="threshold",
-        upsample_preds=False,
-        symmetric=False,
-        name=None,
-        attenuate_cert=None,
+        sample_mode = "threshold",
+        upsample_preds = False,
+        symmetric = False,
+        name = None,
+        attenuate_cert = None,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -530,26 +438,24 @@ class RegressionMatcher(nn.Module):
         self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
         self.sample_mode = sample_mode
         self.upsample_preds = upsample_preds
-        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
+        self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
-
+
     def get_output_resolution(self):
         if not self.upsample_preds:
             return self.h_resized, self.w_resized
         else:
             return self.upsample_res
-
-    def extract_backbone_features(self, batch, batched=True, upsample=False):
+
+    def extract_backbone_features(self, batch, batched = True, upsample = False):
        x_q = batch["im_A"]
        x_s = batch["im_B"]
        if batched:
-            X = torch.cat((x_q, x_s), dim=0)
-            feature_pyramid = self.encoder(X, upsample=upsample)
+            X = torch.cat((x_q, x_s), dim = 0)
+            feature_pyramid = self.encoder(X, upsample = upsample)
        else:
-            feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder(
-                x_s, upsample=upsample
-            )
+            feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
        return feature_pyramid
 
     def sample(
@@ -567,28 +473,22 @@ class RegressionMatcher(nn.Module):
             certainty.reshape(-1),
         )
         expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(
-            certainty,
-            num_samples=min(expansion_factor * num, len(certainty)),
-            replacement=False,
-        )
+        good_samples = torch.multinomial(certainty,
+                          num_samples = min(expansion_factor*num, len(certainty)),
+                          replacement=False)
         good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         if "balanced" not in self.sample_mode:
             return good_matches, good_certainty
         density = kde(good_matches, std=0.1)
-        p = 1 / (density + 1)
-        p[
-            density < 10
-        ] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(
-            p, num_samples=min(num, len(good_certainty)), replacement=False
-        )
+        p = 1 / (density+1)
+        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(p,
+                          num_samples = min(num,len(good_certainty)),
+                          replacement=False)
         return good_matches[balanced_samples], good_certainty[balanced_samples]
 
-    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
         if batched:
             f_q_pyramid = {
                 scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
@@ -598,42 +498,32 @@ class RegressionMatcher(nn.Module):
             }
         else:
             f_q_pyramid, f_s_pyramid = feature_pyramid
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
-
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
+
         return corresps
 
-    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
         f_q_pyramid = feature_pyramid
         f_s_pyramid = {
-            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
+            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
             for scale, f_scale in feature_pyramid.items()
         }
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
         return corresps
-
+
     def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
-        kpts_A = torch.stack(
-            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1
-        )
-        kpts_B = torch.stack(
-            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1
-        )
+        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
 
     def match(
@@ -642,12 +532,11 @@ class RegressionMatcher(nn.Module):
         im_B_path,
         *args,
         batched=False,
-        device=None,
+        device = None,
     ):
         if device is None:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         from PIL import Image
-
         if isinstance(im_A_path, (str, os.PathLike)):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
         else:
@@ -663,9 +552,9 @@ class RegressionMatcher(nn.Module):
             # Get images in good format
             ws = self.w_resized
             hs = self.h_resized
-
+
             test_transform = get_tuple_transform_ops(
-                resize=(hs, ws), normalize=True, clahe=False
+                resize=(hs, ws), normalize=True, clahe = False
             )
             im_A, im_B = test_transform((im_A, im_B))
             batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -675,32 +564,25 @@ class RegressionMatcher(nn.Module):
             assert w == w2 and h == h2, "For batched images we assume same size"
             batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
             if h != self.h_resized or self.w_resized != w:
-                warn(
-                    "Model resolution and batch resolution differ, may produce unexpected results"
-                )
+                warn("Model resolution and batch resolution differ, may produce unexpected results")
             hs, ws = h, w
         finest_scale = 1
         # Run matcher
         if symmetric:
-            corresps = self.forward_symmetric(batch)
+            corresps = self.forward_symmetric(batch)
         else:
-            corresps = self.forward(batch, batched=True)
+            corresps = self.forward(batch, batched = True)
 
         if self.upsample_preds:
             hs, ws = self.upsample_res
-
+
         if self.attenuate_cert:
             low_res_certainty = F.interpolate(
-                corresps[16]["certainty"],
-                size=(hs, ws),
-                align_corners=False,
-                mode="bilinear",
+                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
             )
             cert_clamp = 0
             factor = 0.5
-            low_res_certainty = (
-                factor * low_res_certainty * (low_res_certainty < cert_clamp)
-            )
+            low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
 
         if self.upsample_preds:
             finest_corresps = corresps[finest_scale]
@@ -711,38 +593,30 @@ class RegressionMatcher(nn.Module):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
             im_A, im_B = test_transform((im_A, im_B))
            im_A, im_B = im_A[None].to(device), im_B[None].to(device)
-            scale_factor = math.sqrt(
-                self.upsample_res[0]
-                * self.upsample_res[1]
-                / (self.w_resized * self.h_resized)
-            )
+            scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
             batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
             if symmetric:
-                corresps = self.forward_symmetric(
-                    batch, upsample=True, batched=True, scale_factor=scale_factor
-                )
+                corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
             else:
-                corresps = self.forward(
-                    batch, batched=True, upsample=True, scale_factor=scale_factor
-                )
-        im_A_to_im_B = corresps[finest_scale]["flow"]
-        certainty = corresps[finest_scale]["certainty"] - (
-            low_res_certainty if self.attenuate_cert else 0
-        )
+                corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+
+        im_A_to_im_B = corresps[finest_scale]["flow"]
+        certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
         if finest_scale != 1:
             im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
             )
             certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
-            )
-        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
+                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+            )
+        im_A_to_im_B = im_A_to_im_B.permute(
+            0, 2, 3, 1
+        )
         # Create im_A meshgrid
         im_A_coords = torch.meshgrid(
             (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
             )
         )
         im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -751,21 +625,25 @@ class RegressionMatcher(nn.Module):
         im_A_coords = im_A_coords.permute(0, 2, 3, 1)
         if (im_A_to_im_B.abs() > 1).any() and True:
             wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-            certainty[wrong[:, None]] = 0
+            certainty[wrong[:,None]] = 0
         im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
         if symmetric:
             A_to_B, B_to_A = im_A_to_im_B.chunk(2)
             q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
             im_B_coords = im_A_coords
             s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-            warp = torch.cat((q_warp, s_warp), dim=2)
+            warp = torch.cat((q_warp, s_warp),dim=2)
             certainty = torch.cat(certainty.chunk(2), dim=3)
         else:
             warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
         if batched:
-            return (warp, certainty[:, 0])
+            return (
+                warp,
+                certainty[:, 0]
+            )
         else:
             return (
                 warp[0],
                 certainty[0, 0],
             )
+
```

(As in encoders.py, several deleted lines were truncated mid-line in the source rendering; they are reconstructed here from context and from the equivalent added one-liners.)
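Taken together, `match`, `sample`, and `to_pixel_coordinates` form the inference path this commit restores. A minimal end-to-end sketch, assuming the RoMa convenience constructor `roma_outdoor` is importable (in this repo it lives under third_party/Roma) and that a CUDA device is present, since the reverted code hardcodes `"cuda"` in several places; file names and image sizes below are example values:

```python
# Minimal sketch of the RegressionMatcher API shown in the diff above.
import torch
from roma import roma_outdoor  # assumed import path for this repo layout

device = torch.device("cuda")
model = roma_outdoor(device=device)

# match() returns a dense warp with (..., 4) normalized [-1, 1] coordinates
# (im_A coords concatenated with their predicted im_B coords) plus certainty.
warp, certainty = model.match("im_A.jpg", "im_B.jpg", device=device)

# sample() draws matches weighted by certainty; with a "balanced" sample_mode
# it additionally reweights by a KDE of the match density (see sample() above).
matches, match_certainty = model.sample(warp, certainty)

# to_pixel_coordinates() maps normalized coordinates back to pixel coordinates.
H_A, W_A = 864, 1152  # source image size (example values)
H_B, W_B = 864, 1152  # target image size (example values)
kpts_A, kpts_B = model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
```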