Spaces:
Runtime error
Runtime error
Commit
·
deab087
1
Parent(s):
9c1dc83
Update flow/flow_utils.py
Browse files- flow/flow_utils.py +8 -5
flow/flow_utils.py
CHANGED
|
@@ -12,6 +12,8 @@ sys.path.insert(0, gmflow_dir)
|
|
| 12 |
from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
|
| 13 |
from utils.utils import InputPadder # noqa: E702 E402
|
| 14 |
|
|
|
|
|
|
|
| 15 |
|
| 16 |
def coords_grid(b, h, w, homogeneous=False, device=None):
|
| 17 |
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
|
@@ -27,7 +29,7 @@ def coords_grid(b, h, w, homogeneous=False, device=None):
|
|
| 27 |
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
| 28 |
|
| 29 |
if device is not None:
|
| 30 |
-
grid = grid.to(
|
| 31 |
|
| 32 |
return grid
|
| 33 |
|
|
@@ -117,7 +119,8 @@ def get_warped_and_mask(flow_model,
|
|
| 117 |
if image3 is None:
|
| 118 |
image3 = image1
|
| 119 |
padder = InputPadder(image1.shape, padding_factor=8)
|
| 120 |
-
image1, image2 = padder.pad(image1[None].
|
|
|
|
| 121 |
results_dict = flow_model(image1,
|
| 122 |
image2,
|
| 123 |
attn_splits_list=[2],
|
|
@@ -150,8 +153,7 @@ class FlowCalc():
|
|
| 150 |
attention_type='swin',
|
| 151 |
ffn_dim_expansion=4,
|
| 152 |
num_transformer_layers=6,
|
| 153 |
-
).to(
|
| 154 |
-
|
| 155 |
checkpoint = torch.load(model_path,
|
| 156 |
map_location=lambda storage, loc: storage)
|
| 157 |
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
|
@@ -168,7 +170,8 @@ class FlowCalc():
|
|
| 168 |
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
| 169 |
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
| 170 |
padder = InputPadder(image1.shape, padding_factor=8)
|
| 171 |
-
image1, image2 = padder.pad(image1[None].
|
|
|
|
| 172 |
results_dict = self.model(image1,
|
| 173 |
image2,
|
| 174 |
attn_splits_list=[2],
|
|
|
|
| 12 |
from gmflow.gmflow import GMFlow # noqa: E702 E402 F401
|
| 13 |
from utils.utils import InputPadder # noqa: E702 E402
|
| 14 |
|
| 15 |
+
global_device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
| 16 |
+
|
| 17 |
|
| 18 |
def coords_grid(b, h, w, homogeneous=False, device=None):
|
| 19 |
y, x = torch.meshgrid(torch.arange(h), torch.arange(w)) # [H, W]
|
|
|
|
| 29 |
grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
|
| 30 |
|
| 31 |
if device is not None:
|
| 32 |
+
grid = grid.to(global_device)
|
| 33 |
|
| 34 |
return grid
|
| 35 |
|
|
|
|
| 119 |
if image3 is None:
|
| 120 |
image3 = image1
|
| 121 |
padder = InputPadder(image1.shape, padding_factor=8)
|
| 122 |
+
image1, image2 = padder.pad(image1[None].to(global_device),
|
| 123 |
+
image2[None].to(global_device))
|
| 124 |
results_dict = flow_model(image1,
|
| 125 |
image2,
|
| 126 |
attn_splits_list=[2],
|
|
|
|
| 153 |
attention_type='swin',
|
| 154 |
ffn_dim_expansion=4,
|
| 155 |
num_transformer_layers=6,
|
| 156 |
+
).to(global_device)
|
|
|
|
| 157 |
checkpoint = torch.load(model_path,
|
| 158 |
map_location=lambda storage, loc: storage)
|
| 159 |
weights = checkpoint['model'] if 'model' in checkpoint else checkpoint
|
|
|
|
| 170 |
image1 = torch.from_numpy(image1).permute(2, 0, 1).float()
|
| 171 |
image2 = torch.from_numpy(image2).permute(2, 0, 1).float()
|
| 172 |
padder = InputPadder(image1.shape, padding_factor=8)
|
| 173 |
+
image1, image2 = padder.pad(image1[None].to(global_device),
|
| 174 |
+
image2[None].to(global_device))
|
| 175 |
results_dict = self.model(image1,
|
| 176 |
image2,
|
| 177 |
attn_splits_list=[2],
|