Spaces:
Runtime error
Runtime error
Vincentqyw
committed on
Commit
·
f517bbf
1
Parent(s):
8af5ecd
fix: roma cpu
Browse files
app.py
CHANGED
|
@@ -86,8 +86,7 @@ def ui_reset_state(
|
|
| 86 |
|
| 87 |
|
| 88 |
def run(config):
|
| 89 |
-
with gr.Blocks(css="footer {visibility: hidden}"
|
| 90 |
-
) as app:
|
| 91 |
gr.Markdown(
|
| 92 |
"""
|
| 93 |
<p align="center">
|
|
|
|
| 86 |
|
| 87 |
|
| 88 |
def run(config):
|
| 89 |
+
with gr.Blocks(css="footer {visibility: hidden}") as app:
|
|
|
|
| 90 |
gr.Markdown(
|
| 91 |
"""
|
| 92 |
<p align="center">
|
third_party/Roma/roma/models/encoders.py
CHANGED
|
@@ -6,6 +6,8 @@ import torch.nn.functional as F
|
|
| 6 |
import torchvision.models as tvm
|
| 7 |
import gc
|
| 8 |
|
|
|
|
|
|
|
| 9 |
|
| 10 |
class ResNet50(nn.Module):
|
| 11 |
def __init__(
|
|
@@ -47,7 +49,7 @@ class ResNet50(nn.Module):
|
|
| 47 |
self.amp_dtype = torch.float32
|
| 48 |
|
| 49 |
def forward(self, x, **kwargs):
|
| 50 |
-
with torch.autocast(
|
| 51 |
net = self.net
|
| 52 |
feats = {1: x}
|
| 53 |
x = net.conv1(x)
|
|
@@ -90,7 +92,7 @@ class VGG19(nn.Module):
|
|
| 90 |
self.amp_dtype = torch.float32
|
| 91 |
|
| 92 |
def forward(self, x, **kwargs):
|
| 93 |
-
with torch.autocast(
|
| 94 |
feats = {}
|
| 95 |
scale = 1
|
| 96 |
for layer in self.layers:
|
|
|
|
| 6 |
import torchvision.models as tvm
|
| 7 |
import gc
|
| 8 |
|
| 9 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 10 |
+
|
| 11 |
|
| 12 |
class ResNet50(nn.Module):
|
| 13 |
def __init__(
|
|
|
|
| 49 |
self.amp_dtype = torch.float32
|
| 50 |
|
| 51 |
def forward(self, x, **kwargs):
|
| 52 |
+
with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
|
| 53 |
net = self.net
|
| 54 |
feats = {1: x}
|
| 55 |
x = net.conv1(x)
|
|
|
|
| 92 |
self.amp_dtype = torch.float32
|
| 93 |
|
| 94 |
def forward(self, x, **kwargs):
|
| 95 |
+
with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
|
| 96 |
feats = {}
|
| 97 |
scale = 1
|
| 98 |
for layer in self.layers:
|
third_party/Roma/roma/models/matcher.py
CHANGED
|
@@ -14,6 +14,8 @@ from roma.utils.local_correlation import local_correlation
|
|
| 14 |
from roma.utils.utils import cls_to_flow_refine
|
| 15 |
from roma.utils.kde import kde
|
| 16 |
|
|
|
|
|
|
|
| 17 |
|
| 18 |
class ConvRefiner(nn.Module):
|
| 19 |
def __init__(
|
|
@@ -118,7 +120,7 @@ class ConvRefiner(nn.Module):
|
|
| 118 |
|
| 119 |
def forward(self, x, y, flow, scale_factor=1, logits=None):
|
| 120 |
b, c, hs, ws = x.shape
|
| 121 |
-
with torch.autocast(
|
| 122 |
with torch.no_grad():
|
| 123 |
x_hat = F.grid_sample(
|
| 124 |
y,
|
|
@@ -129,8 +131,8 @@ class ConvRefiner(nn.Module):
|
|
| 129 |
if self.has_displacement_emb:
|
| 130 |
im_A_coords = torch.meshgrid(
|
| 131 |
(
|
| 132 |
-
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=
|
| 133 |
-
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=
|
| 134 |
)
|
| 135 |
)
|
| 136 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
@@ -423,7 +425,7 @@ class Decoder(nn.Module):
|
|
| 423 |
corresps[ins] = {}
|
| 424 |
f1_s, f2_s = f1[ins], f2[ins]
|
| 425 |
if new_scale in self.proj:
|
| 426 |
-
with torch.autocast(
|
| 427 |
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
| 428 |
|
| 429 |
if ins in coarse_scales:
|
|
@@ -643,7 +645,7 @@ class RegressionMatcher(nn.Module):
|
|
| 643 |
device=None,
|
| 644 |
):
|
| 645 |
if device is None:
|
| 646 |
-
device = torch.device(
|
| 647 |
from PIL import Image
|
| 648 |
|
| 649 |
if isinstance(im_A_path, (str, os.PathLike)):
|
|
@@ -739,8 +741,8 @@ class RegressionMatcher(nn.Module):
|
|
| 739 |
# Create im_A meshgrid
|
| 740 |
im_A_coords = torch.meshgrid(
|
| 741 |
(
|
| 742 |
-
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=
|
| 743 |
-
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=
|
| 744 |
)
|
| 745 |
)
|
| 746 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
|
| 14 |
from roma.utils.utils import cls_to_flow_refine
|
| 15 |
from roma.utils.kde import kde
|
| 16 |
|
| 17 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 18 |
+
|
| 19 |
|
| 20 |
class ConvRefiner(nn.Module):
|
| 21 |
def __init__(
|
|
|
|
| 120 |
|
| 121 |
def forward(self, x, y, flow, scale_factor=1, logits=None):
|
| 122 |
b, c, hs, ws = x.shape
|
| 123 |
+
with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
|
| 124 |
with torch.no_grad():
|
| 125 |
x_hat = F.grid_sample(
|
| 126 |
y,
|
|
|
|
| 131 |
if self.has_displacement_emb:
|
| 132 |
im_A_coords = torch.meshgrid(
|
| 133 |
(
|
| 134 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
| 135 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
| 136 |
)
|
| 137 |
)
|
| 138 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
|
|
|
| 425 |
corresps[ins] = {}
|
| 426 |
f1_s, f2_s = f1[ins], f2[ins]
|
| 427 |
if new_scale in self.proj:
|
| 428 |
+
with torch.autocast(device, self.amp_dtype):
|
| 429 |
f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
|
| 430 |
|
| 431 |
if ins in coarse_scales:
|
|
|
|
| 645 |
device=None,
|
| 646 |
):
|
| 647 |
if device is None:
|
| 648 |
+
device = torch.device(device if torch.cuda.is_available() else "cpu")
|
| 649 |
from PIL import Image
|
| 650 |
|
| 651 |
if isinstance(im_A_path, (str, os.PathLike)):
|
|
|
|
| 741 |
# Create im_A meshgrid
|
| 742 |
im_A_coords = torch.meshgrid(
|
| 743 |
(
|
| 744 |
+
torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
|
| 745 |
+
torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
|
| 746 |
)
|
| 747 |
)
|
| 748 |
im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
|
third_party/Roma/roma/models/transformer/__init__.py
CHANGED
|
@@ -7,6 +7,8 @@ from .layers.block import Block
|
|
| 7 |
from .layers.attention import MemEffAttention
|
| 8 |
from .dinov2 import vit_large
|
| 9 |
|
|
|
|
|
|
|
| 10 |
|
| 11 |
class TransformerDecoder(nn.Module):
|
| 12 |
def __init__(
|
|
@@ -51,7 +53,7 @@ class TransformerDecoder(nn.Module):
|
|
| 51 |
return self._scales.copy()
|
| 52 |
|
| 53 |
def forward(self, gp_posterior, features, old_stuff, new_scale):
|
| 54 |
-
with torch.autocast(
|
| 55 |
B, C, H, W = gp_posterior.shape
|
| 56 |
x = torch.cat((gp_posterior, features), dim=1)
|
| 57 |
B, C, H, W = x.shape
|
|
|
|
| 7 |
from .layers.attention import MemEffAttention
|
| 8 |
from .dinov2 import vit_large
|
| 9 |
|
| 10 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 11 |
+
|
| 12 |
|
| 13 |
class TransformerDecoder(nn.Module):
|
| 14 |
def __init__(
|
|
|
|
| 53 |
return self._scales.copy()
|
| 54 |
|
| 55 |
def forward(self, gp_posterior, features, old_stuff, new_scale):
|
| 56 |
+
with torch.autocast(device, dtype=self.amp_dtype, enabled=self.amp):
|
| 57 |
B, C, H, W = gp_posterior.shape
|
| 58 |
x = torch.cat((gp_posterior, features), dim=1)
|
| 59 |
B, C, H, W = x.shape
|
third_party/Roma/roma/utils/local_correlation.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
def local_correlation(
|
| 6 |
feature0,
|
|
@@ -20,8 +22,8 @@ def local_correlation(
|
|
| 20 |
# If flow is None, assume feature0 and feature1 are aligned
|
| 21 |
coords = torch.meshgrid(
|
| 22 |
(
|
| 23 |
-
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=
|
| 24 |
-
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=
|
| 25 |
)
|
| 26 |
)
|
| 27 |
coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
|
|
@@ -30,10 +32,10 @@ def local_correlation(
|
|
| 30 |
local_window = torch.meshgrid(
|
| 31 |
(
|
| 32 |
torch.linspace(
|
| 33 |
-
-2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=
|
| 34 |
),
|
| 35 |
torch.linspace(
|
| 36 |
-
-2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=
|
| 37 |
),
|
| 38 |
)
|
| 39 |
)
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
|
| 4 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 5 |
+
|
| 6 |
|
| 7 |
def local_correlation(
|
| 8 |
feature0,
|
|
|
|
| 22 |
# If flow is None, assume feature0 and feature1 are aligned
|
| 23 |
coords = torch.meshgrid(
|
| 24 |
(
|
| 25 |
+
torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
|
| 26 |
+
torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
|
| 27 |
)
|
| 28 |
)
|
| 29 |
coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
|
|
|
|
| 32 |
local_window = torch.meshgrid(
|
| 33 |
(
|
| 34 |
torch.linspace(
|
| 35 |
+
-2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
|
| 36 |
),
|
| 37 |
torch.linspace(
|
| 38 |
+
-2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
|
| 39 |
),
|
| 40 |
)
|
| 41 |
)
|