# Copyright 2024 DeepMind Technologies Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Pytorch model utilities."""
import math
from typing import Any, Sequence, Union
import numpy as np
import torch
import torch.nn.functional as F


def bilinear(x: torch.Tensor, resolution: tuple[int, int]) -> torch.Tensor:
  """Resizes a 5D tensor using bilinear interpolation.

  Args:
    x: A 5D tensor of shape (B, T, H, W, C) where B is batch size, T is time,
      H is height, W is width, and C is the number of channels.
    resolution: The target spatial resolution as a tuple (new_height,
      new_width), matching the (H, W) order expected by F.interpolate.

  Returns:
    The resized tensor of shape (B, T, new_height, new_width, C).
  """
  b, t, h, w, c = x.size()
  # Fold time into the channel dimension so a single 2D resize handles all
  # frames at once.
  x = x.permute(0, 1, 4, 2, 3).reshape(b, t * c, h, w)
  x = F.interpolate(x, size=resolution, mode='bilinear', align_corners=False)
  b, _, h, w = x.size()
  x = x.reshape(b, t, c, h, w).permute(0, 1, 3, 4, 2)
  return x
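
# Illustrative usage (not part of the original file); shapes follow the
# (B, T, H, W, C) layout documented above:
#   feats = torch.randn(2, 8, 32, 32, 16)
#   resized = bilinear(feats, (64, 48))  # -> (2, 8, 64, 48, 16)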


def map_coordinates_3d(
    feats: torch.Tensor, coordinates: torch.Tensor
) -> torch.Tensor:
  """Maps 3D coordinates to corresponding features using bilinear interpolation.

  Args:
    feats: A 5D tensor of features with shape (B, W, H, D, C), where B is
      batch size, W is width, H is height, D is depth, and C is the number of
      channels.
    coordinates: A 3D tensor of coordinates with shape (B, N, 3), where N is
      the number of coordinates and the last dimension represents (W, H, D)
      coordinates.

  Returns:
    The mapped features tensor of shape (B, N, C).
  """
  x = feats.permute(0, 4, 1, 2, 3)
  y = coordinates[:, :, None, None, :].float().clone()
  # Shift the first coordinate to the pixel center before normalizing.
  y[..., 0] = y[..., 0] + 0.5
  # Normalize to [-1, 1] and reverse the coordinate order to match the
  # (x, y, z) convention expected by grid_sample.
  y = 2 * (y / torch.tensor(x.shape[2:], device=y.device)) - 1
  y = torch.flip(y, dims=(-1,))
  out = (
      F.grid_sample(
          x, y, mode='bilinear', align_corners=False, padding_mode='border'
      )
      .squeeze(dim=(3, 4))
      .permute(0, 2, 1)
  )
  return out
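
# Illustrative usage (not part of the original file): sample a (B, W, H, D, C)
# feature volume at N continuous (W, H, D) locations.
#   vol = torch.randn(1, 16, 16, 8, 32)
#   pts = torch.tensor([[[4.0, 7.5, 2.0], [10.0, 3.0, 6.5]]])  # (1, 2, 3)
#   sampled = map_coordinates_3d(vol, pts)  # -> (1, 2, 32)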


def map_coordinates_2d(
    feats: torch.Tensor, coordinates: torch.Tensor
) -> torch.Tensor:
  """Maps 2D coordinates to feature maps using bilinear interpolation.

  The function performs bilinear interpolation on the feature maps (`feats`)
  at the specified `coordinates`, which are normalized to the range [-1, 1]
  internally. The result is a tensor of sampled features corresponding to
  these coordinates.

  Args:
    feats (Tensor): A 5D tensor of shape (N, T, H, W, C) representing feature
      maps, where N is the batch size, T is the number of frames, H and W are
      height and width, and C is the number of channels.
    coordinates (Tensor): A 5D tensor of shape (N, P, T, S, XY) representing
      coordinates, where N is the batch size, P is the number of points, T is
      the number of frames, S is the number of samples, and XY represents the
      2D coordinates in (y, x) order.

  Returns:
    Tensor: A 5D tensor of the sampled features corresponding to the
    given coordinates, of shape (N, P, T, S, C).
  """
  n, t, h, w, c = feats.shape
  x = feats.permute(0, 1, 4, 2, 3).view(n * t, c, h, w)
  n, p, t, s, xy = coordinates.shape
  y = coordinates.permute(0, 2, 1, 3, 4).reshape(n * t, p, s, xy)
  # Normalize to [-1, 1] and flip (y, x) -> (x, y) as expected by grid_sample.
  y = 2 * (y / torch.tensor([h, w], device=feats.device)) - 1
  y = torch.flip(y, dims=(-1,)).float()
  out = F.grid_sample(
      x, y, mode='bilinear', align_corners=False, padding_mode='zeros'
  )
  _, c, _, _ = out.shape
  out = out.permute(0, 2, 3, 1).view(n, t, p, s, c).permute(0, 2, 1, 3, 4)
  return out
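
# Illustrative usage (not part of the original file): sample features for
# P=5 points with S=3 samples each across T=4 frames.
#   feats = torch.randn(2, 4, 32, 32, 64)    # (N, T, H, W, C)
#   coords = torch.rand(2, 5, 4, 3, 2) * 32  # (N, P, T, S, 2) in (y, x) order
#   out = map_coordinates_2d(feats, coords)  # -> (2, 5, 4, 3, 64)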


def soft_argmax_heatmap_batched(softmax_val, threshold=5):
  """Computes a thresholded soft argmax over a batch of heatmaps.

  The soft argmax is restricted to a window of radius `threshold` around the
  hard argmax of each heatmap, which makes it robust to secondary modes far
  from the peak.

  Args:
    softmax_val: A tensor of heatmaps of shape (B, H, W, D1, D2); the soft
      argmax is computed over the trailing (D1, D2) grid and the leading
      dimensions are treated as batch dimensions.
    threshold: Radius, in grid cells, of the window around the hard argmax
      that contributes to the weighted average.

  Returns:
    A tensor of shape (B, H, W, 2) of (x, y) positions on the heatmap grid,
    offset by half a cell so that positions refer to cell centers.
  """
  b, h, w, d1, d2 = softmax_val.shape
  y, x = torch.meshgrid(
      torch.arange(d1, device=softmax_val.device),
      torch.arange(d2, device=softmax_val.device),
      indexing='ij',
  )
  coords = torch.stack([x + 0.5, y + 0.5], dim=-1).to(softmax_val.device)
  softmax_val_flat = softmax_val.reshape(b, h, w, -1)
  argmax_pos = torch.argmax(softmax_val_flat, dim=-1)
  # Grid position of the hard argmax, used to mask out far-away cells.
  pos = coords.reshape(-1, 2)[argmax_pos]
  valid = (
      torch.sum(
          torch.square(
              coords[None, None, None, :, :, :] - pos[:, :, :, None, None, :]
          ),
          dim=-1,
          keepdim=True,
      )
      < threshold**2
  )
  weighted_sum = torch.sum(
      coords[None, None, None, :, :, :]
      * valid
      * softmax_val[:, :, :, :, :, None],
      dim=(3, 4),
  )
  sum_of_weights = torch.maximum(
      torch.sum(valid * softmax_val[:, :, :, :, :, None], dim=(3, 4)),
      torch.tensor(1e-12, device=softmax_val.device),
  )
  return weighted_sum / sum_of_weights
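
# Illustrative usage (not part of the original file): 64x64 heatmaps for
# 5 points in 8 frames.
#   heat = torch.softmax(torch.randn(1, 5, 8, 64 * 64), dim=-1)
#   heat = heat.reshape(1, 5, 8, 64, 64)
#   points = soft_argmax_heatmap_batched(heat)  # -> (1, 5, 8, 2)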


def heatmaps_to_points(
    all_pairs_softmax,
    image_shape,
    threshold=5,
    query_points=None,
):
  """Convert heatmaps to points using soft argmax."""
  out_points = soft_argmax_heatmap_batched(all_pairs_softmax, threshold)
  feature_grid_shape = all_pairs_softmax.shape[1:]
  # Note: out_points is now [x, y]; we need to divide by [width, height].
  # image_shape[3] is width and image_shape[2] is height.
  out_points = convert_grid_coordinates(
      out_points,
      feature_grid_shape[3:1:-1],
      image_shape[3:1:-1],
  )
  assert feature_grid_shape[1] == image_shape[1]
  if query_points is not None:
    # The [..., 0:1] is because we only care about the frame index.
    query_frame = convert_grid_coordinates(
        query_points.detach(),
        image_shape[1:4],
        feature_grid_shape[1:4],
        coordinate_format='tyx',
    )[..., 0:1]
    query_frame = torch.round(query_frame)
    frame_indices = torch.arange(image_shape[1], device=query_frame.device)[
        None, None, :
    ]
    is_query_point = query_frame == frame_indices
    is_query_point = is_query_point[:, :, :, None]
    # On the query frame itself, snap the prediction to the query point.
    out_points = (
        out_points * ~is_query_point
        + torch.flip(query_points[:, :, None], dims=(-1,))[..., 0:2]
        * is_query_point
    )
  return out_points
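
# Illustrative usage (not part of the original file): heatmaps on a 32x32
# feature grid mapped to (x, y) coordinates in 256x256 frames.
#   heat = torch.softmax(torch.randn(1, 5, 8, 32 * 32), dim=-1)
#   heat = heat.reshape(1, 5, 8, 32, 32)
#   points = heatmaps_to_points(heat, (1, 8, 256, 256))  # -> (1, 5, 8, 2)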


def is_same_res(r1, r2):
  """Test if two image resolutions are the same."""
  return all([x == y for x, y in zip(r1, r2)])


def convert_grid_coordinates(
    coords: torch.Tensor,
    input_grid_size: Sequence[int],
    output_grid_size: Sequence[int],
    coordinate_format: str = 'xy',
) -> torch.Tensor:
  """Convert grid coordinates to correct format.

  Rescales coordinates expressed on a grid of size `input_grid_size` to the
  corresponding positions on a grid of size `output_grid_size`.
  """
  if isinstance(input_grid_size, tuple):
    input_grid_size = torch.tensor(input_grid_size, device=coords.device)
  if isinstance(output_grid_size, tuple):
    output_grid_size = torch.tensor(output_grid_size, device=coords.device)

  if coordinate_format == 'xy':
    if input_grid_size.shape[0] != 2 or output_grid_size.shape[0] != 2:
      raise ValueError(
          'If coordinate_format is xy, the shapes must be length 2.'
      )
  elif coordinate_format == 'tyx':
    if input_grid_size.shape[0] != 3 or output_grid_size.shape[0] != 3:
      raise ValueError(
          'If coordinate_format is tyx, the shapes must be length 3.'
      )
    if input_grid_size[0] != output_grid_size[0]:
      raise ValueError('Converting frame count is not supported.')
  else:
    raise ValueError('Recognized coordinate formats are xy and tyx.')

  position_in_grid = coords
  position_in_grid = position_in_grid * output_grid_size / input_grid_size
  return position_in_grid
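
# Illustrative usage (not part of the original file): map (x, y) points from a
# 32x32 feature grid to a 256x192 (width, height) image.
#   pts = torch.tensor([[8.0, 16.0]])
#   convert_grid_coordinates(pts, (32, 32), (256, 192))  # -> [[64., 96.]]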


def generate_default_resolutions(full_size, train_size, num_levels=None):
  """Generate a list of logarithmically-spaced resolutions.

  Generated resolutions are between train_size and full_size, inclusive, with
  num_levels different resolutions total. Useful for generating the input to
  refinement_resolutions in PIPs.

  Args:
    full_size: 2-tuple of ints. The full image size desired.
    train_size: 2-tuple of ints. The smallest refinement level. Should
      typically match the training resolution, which is (256, 256) for TAPIR.
    num_levels: number of levels. Typically each resolution should be less
      than twice the size of prior resolutions.

  Returns:
    A list of resolutions.
  """
  if all([x == y for x, y in zip(train_size, full_size)]):
    return [train_size]

  if num_levels is None:
    size_ratio = np.array(full_size) / np.array(train_size)
    num_levels = int(np.ceil(np.max(np.log2(size_ratio))) + 1)

  if num_levels <= 1:
    return [train_size]

  h, w = full_size[0:2]
  if h % 8 != 0 or w % 8 != 0:
    print(
        'Warning: output size is not a multiple of 8. Final layer '
        + 'will round size down.'
    )
  ll_h, ll_w = train_size[0:2]

  sizes = []
  for i in range(num_levels):
    # Interpolate geometrically between train_size and full_size, rounding
    # each level down to a multiple of 8.
    size = (
        int(round((ll_h * (h / ll_h) ** (i / (num_levels - 1))) // 8)) * 8,
        int(round((ll_w * (w / ll_w) ** (i / (num_levels - 1))) // 8)) * 8,
    )
    sizes.append(size)
  return sizes
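
# Illustrative usage (not part of the original file):
#   generate_default_resolutions((480, 640), (256, 256))
#   # -> [(256, 256), (344, 400), (480, 640)]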


class Conv2dSamePadding(torch.nn.Conv2d):
  """Conv2d with TensorFlow-style 'SAME' padding."""

  def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
    return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    ih, iw = x.size()[-2:]
    pad_h = self.calc_same_pad(
        i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0]
    )
    pad_w = self.calc_same_pad(
        i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1]
    )
    if pad_h > 0 or pad_w > 0:
      x = F.pad(
          x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
      )
    return F.conv2d(
        x,
        self.weight,
        self.bias,
        self.stride,
        # Padding is applied explicitly above, so pass 0 instead of
        # self.padding here.
        0,
        self.dilation,
        self.groups,
    )
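
# Illustrative usage (not part of the original file): with 'SAME' padding, a
# stride-2 convolution halves the spatial size (rounding up) for any kernel.
#   conv = Conv2dSamePadding(3, 16, kernel_size=3, stride=2)
#   out = conv(torch.randn(1, 3, 33, 64))  # -> (1, 16, 17, 32)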