Spaces:
Building
Building
File size: 4,493 Bytes
4187c6f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
from typing import Callable, Optional, Union, Sequence
import numpy as np
import torch
import torchvision.transforms.functional as tvf
import collections
from scipy.spatial.transform import Rotation
from ..utils.geometry import from_homogeneous, to_homogeneous
from ..utils.wrappers import Camera
def rectify_image(
    image: torch.Tensor,
    cam: Camera,
    roll: float,
    pitch: Optional[float] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Warp an image to compensate a camera roll (and optionally pitch) rotation.

    Every output pixel is normalized with ``cam``, lifted to homogeneous
    coordinates, rotated, projected back and denormalized; the input image is
    then bilinearly sampled at the resulting location.

    Args:
        image: image of shape (C, H, W); a batch dimension is added internally
            because ``grid_sample`` works on 4-D tensors.
        cam: camera model providing ``normalize`` / ``denormalize``.
        roll: rotation angle in degrees about the Z axis.
        pitch: optional rotation angle in degrees about the X axis. When given,
            the rotation is the Euler composition "ZX" of (roll, pitch).
        valid: optional (H, W) mask of valid input pixels. It is resampled with
            nearest-neighbour interpolation; locations outside the image read
            as 0 (grid_sample's default zero padding) and become invalid.

    Returns:
        ``(rectified, valid)`` where ``rectified`` has shape (C, H, W) and
        ``valid`` is a boolean (H, W) mask. If no input mask was given, a pixel
        is valid when its sampling location falls inside [-1, 1] in
        grid_sample's normalized coordinates.
    """
    *_, h, w = image.shape
    grid = torch.meshgrid(
        [torch.arange(w, device=image.device), torch.arange(h, device=image.device)],
        indexing="xy",
    )
    # (H, W, 2) grid of integer (x, y) pixel coordinates.
    grid = torch.stack(grid, -1).to(image.dtype)
    if pitch is not None:
        args = ("ZX", (roll, pitch))
    else:
        args = ("Z", roll)
    R = Rotation.from_euler(*args, degrees=True).as_matrix()
    R = torch.from_numpy(R).to(image)
    # Row-vector convention: x @ R.T applies R to each homogeneous point.
    grid_rect = to_homogeneous(cam.normalize(grid)) @ R.T
    grid_rect = cam.denormalize(from_homogeneous(grid_rect))
    # Map pixel coordinates to grid_sample's [-1, 1] range. With
    # align_corners=False, the +0.5 makes integer coordinates hit pixel centres.
    grid_norm = (grid_rect + 0.5) / grid.new_tensor([w, h]) * 2 - 1
    rectified = torch.nn.functional.grid_sample(
        image[None],
        grid_norm[None],
        align_corners=False,
        mode="bilinear",
    ).squeeze(0)
    if valid is None:
        valid = torch.all((grid_norm >= -1) & (grid_norm <= 1), -1)
    else:
        # grid_sample requires input and grid to share a dtype. Casting the
        # mask to the grid's dtype (rather than always float32) keeps this
        # working for float16 / float64 images as well.
        valid = (
            torch.nn.functional.grid_sample(
                valid[None, None].to(grid_norm.dtype),
                grid_norm[None],
                align_corners=False,
                mode="nearest",
            )[0, 0]
            > 0
        )
    return rectified, valid
def resize_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    fn: Optional[Callable] = None,
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
):
    """Resize an image to a fixed size, or according to max or min edge.

    Args:
        image: tensor whose last two dimensions are (height, width). After a
            resize it is clipped in place to [0, 1], so float images in that
            range are presumably expected.
        size: target size. With ``fn``, a single integer that the edge
            selected by ``fn(h, w)`` is scaled to (e.g. ``fn=max`` makes the
            longer edge equal ``size``). Without ``fn``, either a
            ``(width, height)`` sequence/array or a single integer used for
            both edges. Python ints and NumPy integer scalars are accepted.
        fn: optional callable ``fn(h, w)`` returning the reference edge length.
        camera: optional camera, rescaled with the same factors as the image.
        valid: optional mask tensor whose last two dimensions match the
            image. It is resized with nearest-neighbour interpolation.

    Returns:
        A list ``[image, scale]``, followed by ``camera`` and then ``valid``
        when those were given. ``scale`` is an ``(sx, sy)`` tuple. On the
        ``fn`` path it is the nominal factor ``size / fn(h, w)`` used for both
        axes; rounding of the new size means the exact per-axis ratios may
        differ slightly. Camera and mask are only modified if the size changes.

    Raises:
        ValueError: if ``size`` is neither an integer nor a sequence/array.
        AssertionError: if ``fn`` is given but ``size`` is not an integer.
    """
    integer_types = (int, np.integer)
    *_, h, w = image.shape
    if fn is not None:
        assert isinstance(size, integer_types)
        scale = int(size) / fn(h, w)
        h_new, w_new = int(round(h * scale)), int(round(w * scale))
        scale = (scale, scale)
    else:
        if isinstance(size, (collections.abc.Sequence, np.ndarray)):
            w_new, h_new = size
        elif isinstance(size, integer_types):
            # Normalize NumPy integer scalars so ``scale`` holds Python floats.
            w_new = h_new = int(size)
        else:
            raise ValueError(f"Incorrect new size: {size}")
        scale = (w_new / w, h_new / h)
    if (w, h) != (w_new, h_new):
        mode = tvf.InterpolationMode.BILINEAR
        image = tvf.resize(image, (int(h_new), int(w_new)), interpolation=mode, antialias=True)
        # Bilinear + antialias filtering can overshoot the input range.
        image.clip_(0, 1)
        if camera is not None:
            camera = camera.scale(scale)
        if valid is not None:
            # Nearest-neighbour keeps the mask values binary.
            valid = tvf.resize(
                valid.unsqueeze(0),
                (int(h_new), int(w_new)),
                interpolation=tvf.InterpolationMode.NEAREST,
            ).squeeze(0)
    ret = [image, scale]
    if camera is not None:
        ret.append(camera)
    if valid is not None:
        ret.append(valid)
    return ret
def pad_image(
    image: torch.Tensor,
    size: Union[int, Sequence, np.ndarray],
    camera: Optional[Camera] = None,
    valid: Optional[torch.Tensor] = None,
    crop_and_center: bool = False,
):
    """Zero-pad (and optionally center-crop) an image to a fixed size.

    The kept image content is always written starting at the top-left corner
    of the output, and zero padding is added at the bottom/right. With
    ``crop_and_center``, axes larger than the target are cropped symmetrically
    around the centre; axes smaller than the target are still padded at the
    bottom/right only.

    Args:
        image: tensor whose last two dimensions are (height, width); any
            leading dimensions are preserved.
        size: target size, either a single integer (Python or NumPy) for both
            edges or a ``(width, height)`` sequence/array.
        camera: optional camera, cropped consistently with the image.
        valid: optional mask over the input's last two dimensions. Its values
            are copied into the output mask for the retained region.
        crop_and_center: allow cropping when the image exceeds the target size.

    Returns:
        ``(out, out_valid)``, or ``(out, out_valid, camera)`` if a camera was
        given. ``out`` has the image's dtype and device. ``out_valid`` is a
        boolean (h_new, w_new) mask on the same device, and is False on padding.
        When the size already matches, ``out`` is ``image`` itself (not a copy).

    Raises:
        ValueError: if ``size`` is neither an integer nor a sequence/array.
        AssertionError: if ``crop_and_center`` is False and the image is
            larger than the target size.
    """
    if isinstance(size, (int, np.integer)):
        w_new = h_new = int(size)
    elif isinstance(size, (collections.abc.Sequence, np.ndarray)):
        w_new, h_new = size
    else:
        raise ValueError(f"Incorrect new size: {size}")
    *c, h, w = image.shape
    if crop_and_center:
        # Positive entries are amounts to crop, negative ones amounts to pad.
        diff = np.array([w - w_new, h - h_new])
        left, top = left_top = np.round(diff / 2).astype(int)
        right, bottom = diff - left_top
    else:
        assert h <= h_new, f"Image height {h} exceeds target height {h_new}."
        assert w <= w_new, f"Image width {w} exceeds target width {w_new}."
        top = bottom = left = right = 0
    slice_out = np.s_[..., : min(h, h_new), : min(w, w_new)]
    # Only positive offsets crop the input; negative ones mean padding.
    slice_in = np.s_[
        ..., max(top, 0) : h - max(bottom, 0), max(left, 0) : w - max(right, 0)
    ]
    if (w, h) == (w_new, h_new):
        out = image
    else:
        # Allocate on the input's device so GPU tensors are not silently
        # moved to the CPU.
        out = torch.zeros(
            (*c, h_new, w_new), dtype=image.dtype, device=image.device
        )
        out[slice_out] = image[slice_in]
    if camera is not None:
        camera = camera.crop((max(left, 0), max(top, 0)), (w_new, h_new))
    out_valid = torch.zeros((h_new, w_new), dtype=torch.bool, device=image.device)
    out_valid[slice_out] = True if valid is None else valid[slice_in]
    if camera is not None:
        return out, out_valid, camera
    return out, out_valid
|