import collections
import json
import math
import os
import re
import threading
from typing import List, Literal, Optional, Tuple, Union
import gradio as gr
from colorama import Fore, Style, init
init(autoreset=True)
import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import repeat
from PIL import Image
from tqdm.auto import tqdm
from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
from seva.sampling import (
EulerEDMSampler,
MultiviewCFG,
MultiviewTemporalCFG,
VanillaCFG,
)
from seva.utils import seed_everything
try:
    # Nightly builds of torch contain 'dev' in their version string.
version = torch.__version__
IS_TORCH_NIGHTLY = "dev" in version
if IS_TORCH_NIGHTLY:
torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment]
torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment]
torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment]
except Exception:
IS_TORCH_NIGHTLY = False
def pad_indices(
input_indices: List[int],
test_indices: List[int],
T: int,
padding_mode: Literal["first", "last", "none"] = "last",
):
assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
if padding_mode == "last":
padded_indices = [
i for i in range(T) if i not in (input_indices + test_indices)
]
else:
padded_indices = []
input_selects = list(range(len(input_indices)))
test_selects = list(range(len(test_indices)))
if max(input_indices) > max(test_indices):
# last elem from input
input_selects += [input_selects[-1]] * len(padded_indices)
input_indices = input_indices + padded_indices
sorted_inds = np.argsort(input_indices)
input_indices = [input_indices[ind] for ind in sorted_inds]
input_selects = [input_selects[ind] for ind in sorted_inds]
else:
# last elem from test
test_selects += [test_selects[-1]] * len(padded_indices)
test_indices = test_indices + padded_indices
sorted_inds = np.argsort(test_indices)
test_indices = [test_indices[ind] for ind in sorted_inds]
test_selects = [test_selects[ind] for ind in sorted_inds]
if padding_mode == "last":
input_maps = np.array([-1] * T)
test_maps = np.array([-1] * T)
else:
input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
input_maps[input_indices] = input_selects
test_maps[test_indices] = test_selects
return input_indices, test_indices, input_maps, test_maps
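# Illustrative sketch (toy values, not part of the original module): with
# `padding_mode="last"`, leftover slots in a window of length T are filled by
# repeating the last test view. For example,
#   pad_indices([0, 1], [2, 3], T=6, padding_mode="last")
# returns
#   input_indices = [0, 1], test_indices = [2, 3, 4, 5],
#   input_maps = [0, 1, -1, -1, -1, -1], test_maps = [-1, -1, 0, 1, 1, 1],
# i.e. slots 4 and 5 reuse test view 1.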
def assemble(
input,
test,
input_maps,
test_maps,
):
T = len(input_maps)
assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
assert np.logical_xor(input_maps != -1, test_maps != -1).all()
return assembled
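# Continuing the sketch above (toy values): with those maps (as numpy arrays),
# `assemble` fills every slot from exactly one of the two tensors:
#   assemble(input, test,
#            input_maps=np.array([0, 1, -1, -1, -1, -1]),
#            test_maps=np.array([-1, -1, 0, 1, 1, 1]))
# yields a (6, ...) tensor laid out as
#   [input[0], input[1], test[0], test[1], test[1], test[1]].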
def get_resizing_factor(
target_shape: Tuple[int, int], # H, W
current_shape: Tuple[int, int], # H, W
cover_target: bool = True,
# If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
) -> float:
r_bound = target_shape[1] / target_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if cover_target:
if aspect_r >= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r < 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r >= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r < 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
else:
if cover_target:
if aspect_r <= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r > 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r > 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
return factor
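# Illustrative sketch (assumed shapes): with cover_target=True the factor rescales
# `current_shape` so it fully covers `target_shape`. For instance,
#   get_resizing_factor((576, 1024), (720, 1280)) == 0.8   # 720x1280 * 0.8 -> 576x1024
#   get_resizing_factor((576, 1024), (768, 1024)) == 1.0   # already covers the target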
def get_unique_embedder_keys_from_conditioner(conditioner):
keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
keys = [item for sublist in keys for item in sublist] # Flatten list
return set(keys)
def get_wh_with_fixed_shortest_side(w, h, size):
    # If size is None or non-positive, return the original (w, h).
if size is None or size <= 0:
return w, h
if w < h:
new_w = size
new_h = int(size * h / w)
else:
new_h = size
new_w = int(size * w / h)
return new_w, new_h
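# Quick sketch (assumed values): get_wh_with_fixed_shortest_side(1280, 720, 576)
# returns (1024, 576) -- the shorter side (h) is pinned to 576 and the width is
# scaled by the same 576/720 ratio.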
def load_img_and_K(
image_path_or_size: Union[str, torch.Size],
size: Optional[Union[int, Tuple[int, int]]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: torch.Tensor | None = None,
size_stride: int = 1,
center_crop: bool = False,
image_as_tensor: bool = True,
context_rgb: np.ndarray | None = None,
device: str = "cuda",
):
if isinstance(image_path_or_size, torch.Size):
image = Image.new("RGBA", image_path_or_size[::-1])
else:
image = Image.open(image_path_or_size).convert("RGBA")
w, h = image.size
if size is None:
size = (w, h)
image = np.array(image).astype(np.float32) / 255
if image.shape[-1] == 4:
rgb, alpha = image[:, :, :3], image[:, :, 3:]
if context_rgb is not None:
image = rgb * alpha + context_rgb * (1 - alpha)
else:
image = rgb * alpha + (1 - alpha)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)
if isinstance(size, (tuple, list)):
# => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
if scale < 1.0:
pw = math.ceil((W - resize_size[1]) * 0.5)
ph = math.ceil((H - resize_size[0]) * 0.5)
image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if center_crop:
side = min(H, W)
ct = max(0, cy_center - side // 2)
cl = max(0, cx_center - side // 2)
ct = min(ct, image.shape[-2] - side)
cl = min(cl, image.shape[-1] - side)
image = TF.crop(image, top=ct, left=cl, height=side, width=side)
else:
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
if K is not None:
K = K.clone()
if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
else:
K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
K[:2, 2] -= K.new_tensor([cl, ct])
if image_as_tensor:
        # tensor of shape (1, 3, H, W) with values in the range (-1, 1)
image = image.to(device) * 2.0 - 1.0
else:
        # PIL Image with values in the range [0, 255]
image = image.permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray((image * 255).astype(np.uint8))
return image, K
def transform_img_and_K(
image: torch.Tensor,
size: Union[int, Tuple[int, int]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: torch.Tensor | None = None,
size_stride: int = 1,
mode: str = "crop",
):
assert mode in [
"crop",
"pad",
"stretch",
], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
h, w = image.shape[-2:]
if isinstance(size, (tuple, list)):
# => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
if mode == "stretch":
rh, rw = H, W
else:
rfs = get_resizing_factor(
(H, W),
(h, w),
cover_target=mode != "pad",
)
(rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
rh, rw = int(rh / scale), int(rw / scale)
image = torch.nn.functional.interpolate(
image, (rh, rw), mode="area", antialias=False
)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if mode != "pad":
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
pl, pt = 0, 0
else:
pt = max(0, H // 2 - cy_center)
pl = max(0, W // 2 - cx_center)
pb = max(0, H - pt - image.shape[-2])
pr = max(0, W - pl - image.shape[-1])
image = TF.pad(
image,
[pl, pt, pr, pb],
)
cl, ct = 0, 0
if K is not None:
K = K.clone()
# K[:, :2, 2] += K.new_tensor([pl, pt])
if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
else:
K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
return image, K
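# Illustrative sketch (assumed values): for a (1, 3, 720, 1280) image and
# size=(576, 576) in "crop" mode, the image is first resized to (576, 1024)
# (shortest side covers the target) and then cropped to 576x576; K's focal and
# principal entries are scaled by the resize, and the principal point is shifted
# by the crop offset (cl, ct), keeping K consistent with the output pixels.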
lowvram_mode = False
def set_lowvram_mode(mode):
global lowvram_mode
lowvram_mode = mode
def load_model(model, device: str = "cuda"):
model.to(device)
def unload_model(model):
global lowvram_mode
if lowvram_mode:
model.cpu()
torch.cuda.empty_cache()
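# Typical usage sketch (hypothetical `ae` module): with low-VRAM mode enabled, a
# module lives on the GPU only for the duration of its forward pass:
#   set_lowvram_mode(True)
#   load_model(ae)           # ae.to("cuda")
#   latents = ae.encode(x)   # hypothetical call
#   unload_model(ae)         # ae.cpu() + torch.cuda.empty_cache()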
def infer_prior_stats(
T,
num_input_frames,
num_total_frames,
version_dict,
):
options = version_dict["options"]
chunk_strategy = options.get("chunk_strategy", "nearest")
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
# get traj_prior_c2ws for 2-pass sampling
if chunk_strategy.startswith("interp"):
        # Start and end have already taken up two slots.
        # The +1 means we need X + 1 prior frames to bound X forward passes over all test frames.
        # Tuning up `num_prior_frames_ratio` helps when you observe sudden jumps in the
        # generated frames due to insufficient prior frames. This option is effective for
        # complicated trajectories and when the `interp` strategy is used (usually the
        # semi-dense-view regime). The recommended range is [1.0 (default), 1.5].
if num_input_frames >= options.get("num_input_semi_dense", 9):
num_prior_frames = (
math.ceil(
num_total_frames
/ (T_second_pass - 2)
* options.get("num_prior_frames_ratio", 1.0)
)
+ 1
)
if num_prior_frames + num_input_frames < T_first_pass:
num_prior_frames = T_first_pass - num_input_frames
num_prior_frames = max(
num_prior_frames,
options.get("num_prior_frames", 0),
)
T_first_pass = num_prior_frames + num_input_frames
if "gt" in chunk_strategy:
T_second_pass = T_second_pass + num_input_frames
# Dynamically update context window length.
version_dict["T"] = [T_first_pass, T_second_pass]
else:
num_prior_frames = (
math.ceil(
num_total_frames
/ (
T_second_pass
- 2
- (num_input_frames if "gt" in chunk_strategy else 0)
)
* options.get("num_prior_frames_ratio", 1.0)
)
+ 1
)
if num_prior_frames + num_input_frames < T_first_pass:
num_prior_frames = T_first_pass - num_input_frames
num_prior_frames = max(
num_prior_frames,
options.get("num_prior_frames", 0),
)
else:
num_prior_frames = max(
T_first_pass - num_input_frames,
options.get("num_prior_frames", 0),
)
if num_input_frames >= options.get("num_input_semi_dense", 9):
T_first_pass = num_prior_frames + num_input_frames
# Dynamically update context window length.
version_dict["T"] = [T_first_pass, T_second_pass]
return num_prior_frames
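# Worked sketch (assumed values): with chunk_strategy="interp", T=[21, 21],
# num_input_frames=3 and num_total_frames=80, the `else` branch above computes
#   ceil(80 / (21 - 2)) + 1 = 6
# prior frames; since 6 + 3 < 21, it is raised to T_first_pass - num_input_frames = 18
# so the first pass fills a full context window.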
def infer_prior_inds(
c2ws,
num_prior_frames,
input_frame_indices,
options,
):
chunk_strategy = options.get("chunk_strategy", "nearest")
if chunk_strategy.startswith("interp"):
prior_frame_indices = np.array(
[i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
)
prior_frame_indices = prior_frame_indices[
np.ceil(
np.linspace(
0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
)
).astype(int)
        ]  # using ceil here is safer for corner cases
else:
prior_frame_indices = []
while len(prior_frame_indices) < num_prior_frames:
closest_distance = np.abs(
np.arange(c2ws.shape[0])[None]
- np.concatenate(
[np.array(input_frame_indices), np.array(prior_frame_indices)]
)[:, None]
).min(0)
prior_frame_indices.append(np.argsort(closest_distance)[-1])
return np.sort(prior_frame_indices)
def compute_relative_inds(
source_inds,
target_inds,
):
assert len(source_inds) > 2
# compute relative indices of target_inds within source_inds
relative_inds = []
for ind in target_inds:
if ind in source_inds:
relative_ind = int(np.where(source_inds == ind)[0][0])
elif ind < source_inds[0]:
# extrapolate
relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
elif ind > source_inds[-1]:
# extrapolate
relative_ind = len(source_inds) + (
(ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
)
else:
# interpolate
lower_inds = source_inds[source_inds < ind]
upper_inds = source_inds[source_inds > ind]
if len(lower_inds) > 0 and len(upper_inds) > 0:
lower_ind = lower_inds[-1]
upper_ind = upper_inds[0]
relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
relative_ind = relative_lower_ind + (ind - lower_ind) / (
upper_ind - lower_ind
) * (relative_upper_ind - relative_lower_ind)
            else:
                # Out of range: record a placeholder and skip the shared append below.
                relative_inds.append(float("nan"))
                continue
        relative_inds.append(relative_ind)
return relative_inds
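# Illustrative sketch (toy indices): with source_inds = np.array([0, 4, 8]) and
# target_inds = [2, 10],
#   - 2 falls between sources 0 and 4, giving relative index 0 + (2 - 0) / (4 - 0) = 0.5;
#   - 10 lies beyond the last source and extrapolates to 3 + (10 - 8) / (8 - 4) = 3.5.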
def find_nearest_source_inds(
source_c2ws,
target_c2ws,
nearest_num=1,
mode="translation",
):
dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
sorted_inds = np.argsort(dists, axis=0).T
return sorted_inds[:, :nearest_num]
def chunk_input_and_test(
T,
input_c2ws,
test_c2ws,
input_ords, # orders
test_ords, # orders
options,
task: str = "img2img",
chunk_strategy: str = "gt",
gt_input_inds: list = [],
):
M, N = input_c2ws.shape[0], test_c2ws.shape[0]
chunks = []
if chunk_strategy.startswith("gt"):
assert len(gt_input_inds) < T, (
f"Number of gt input frames {len(gt_input_inds)} should be "
f"less than {T} when `gt` chunking strategy is used."
)
assert (
list(range(M)) == gt_input_inds
), "All input_c2ws should be gt when `gt` chunking strategy is used."
# LEGACY CHUNKING STRATEGY
# num_test_per_chunk = T - len(gt_input_inds)
# test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds]
# for i in range(0, test_c2ws.shape[0], num_test_per_chunk):
# chunk = ["NULL"] * T
# for j, k in enumerate(gt_input_inds):
# chunk[k] = f"!{j:03d}"
# for j, k in enumerate(
# test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]]
# ):
# chunk[k] = f">{i + j:03d}"
# chunks.append(chunk)
num_test_seen = 0
while num_test_seen < N:
chunk = [f"!{i:03d}" for i in gt_input_inds]
if chunk_strategy != "gt" and num_test_seen > 0:
pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
if (N - num_test_seen) >= math.floor(
(T - len(gt_input_inds)) * pseudo_num_ratio
):
pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
else:
pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))
if "ltr" in chunk_strategy:
chunk.extend(
[
f"!{i + len(gt_input_inds):03d}"
for i in range(num_test_seen - pseudo_num, num_test_seen)
]
)
elif "nearest" in chunk_strategy:
source_inds = np.concatenate(
[
find_nearest_source_inds(
test_c2ws[:num_test_seen],
test_c2ws[num_test_seen:],
nearest_num=1, # pseudo_num,
mode="rotation",
),
find_nearest_source_inds(
test_c2ws[:num_test_seen],
test_c2ws[num_test_seen:],
nearest_num=1, # pseudo_num,
mode="translation",
),
],
axis=1,
)
                    ####### [HACK ALERT] keep running until pseudo num is stabilized ########
temp_pseudo_num = pseudo_num
while True:
nearest_source_inds = np.concatenate(
[
np.sort(
[
ind
for (ind, _) in collections.Counter(
[
item
for item in source_inds[
: T
- len(gt_input_inds)
- temp_pseudo_num
]
.flatten()
.tolist()
if item
!= (
num_test_seen - 1
) # exclude the last one here
]
).most_common(pseudo_num - 1)
],
).astype(int),
[num_test_seen - 1], # always keep the last one
]
)
if len(nearest_source_inds) >= temp_pseudo_num:
                            break  # stabilized
else:
temp_pseudo_num = len(nearest_source_inds)
pseudo_num = len(nearest_source_inds)
########################################################################
chunk.extend(
[f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
)
else:
raise NotImplementedError(
f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
)
chunk.extend(
[
f">{i:03d}"
for i in range(
num_test_seen,
min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
)
]
)
else:
chunk.extend(
[
f">{i:03d}"
for i in range(
num_test_seen,
min(num_test_seen + T - len(gt_input_inds), N),
)
]
)
num_test_seen += sum([1 for c in chunk if c.startswith(">")])
if len(chunk) < T:
chunk.extend(["NULL"] * (T - len(chunk)))
chunks.append(chunk)
elif chunk_strategy.startswith("nearest"):
input_imgs = np.array([f"!{i:03d}" for i in range(M)])
test_imgs = np.array([f">{i:03d}" for i in range(N)])
match = re.match(r"^nearest-(\d+)$", chunk_strategy)
if match:
nearest_num = int(match.group(1))
assert (
nearest_num < T
), f"Nearest number of {nearest_num} should be less than {T}."
source_inds = find_nearest_source_inds(
input_c2ws,
test_c2ws,
nearest_num=nearest_num,
mode="translation", # during the second pass, consider translation only is enough
)
for i in range(0, N, T - nearest_num):
nearest_source_inds = np.sort(
[
ind
for (ind, _) in collections.Counter(
source_inds[i : i + T - nearest_num].flatten().tolist()
).most_common(nearest_num)
]
)
chunk = (
input_imgs[nearest_source_inds].tolist()
+ test_imgs[i : i + T - nearest_num].tolist()
)
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
else:
# do not always condition on gt cond frames
if "gt" not in chunk_strategy:
gt_input_inds = []
source_inds = find_nearest_source_inds(
input_c2ws,
test_c2ws,
nearest_num=1,
mode="translation", # during the second pass, consider translation only is enough
)[:, 0]
test_inds_per_input = {}
for test_idx, input_idx in enumerate(source_inds):
if input_idx not in test_inds_per_input:
test_inds_per_input[input_idx] = []
test_inds_per_input[input_idx].append(test_idx)
num_test_seen = 0
chunk = input_imgs[gt_input_inds].tolist()
candidate_input_inds = sorted(list(test_inds_per_input.keys()))
while num_test_seen < N:
input_idx = candidate_input_inds[0]
test_inds = test_inds_per_input[input_idx]
input_is_cond = input_idx in gt_input_inds
prefix_inds = [] if input_is_cond else [input_idx]
if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
if chunk:
chunk += ["NULL"] * (T - len(chunk))
chunks.append(chunk)
chunk = input_imgs[gt_input_inds].tolist()
if num_test_seen >= N:
break
continue
candidate_chunk = (
input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
)
space_left = T - len(chunk)
if len(candidate_chunk) <= space_left:
chunk.extend(candidate_chunk)
num_test_seen += len(test_inds)
candidate_input_inds.pop(0)
else:
chunk.extend(candidate_chunk[:space_left])
num_input_idx = 0 if input_is_cond else 1
num_test_seen += space_left - num_input_idx
test_inds_per_input[input_idx] = test_inds[
space_left - num_input_idx :
]
if len(chunk) == T:
chunks.append(chunk)
chunk = input_imgs[gt_input_inds].tolist()
if chunk and chunk != input_imgs[gt_input_inds].tolist():
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
elif chunk_strategy.startswith("interp"):
# `interp` chunk requires ordering info
assert input_ords is not None and test_ords is not None, (
"When using `interp` chunking strategy, ordering of input "
"and test frames should be provided."
)
        # if chunk_strategy is `interp*` and task is `img2trajvid*`, we will not
# use input views since their order info within target views is unknown
if "img2trajvid" in task:
assert (
list(range(len(gt_input_inds))) == gt_input_inds
), "`img2trajvid` task should put `gt_input_inds` in start."
input_c2ws = input_c2ws[
[ind for ind in range(M) if ind not in gt_input_inds]
]
input_ords = [
input_ords[ind] for ind in range(M) if ind not in gt_input_inds
]
M = input_c2ws.shape[0]
input_ords = [0] + input_ords # this is a hack accounting for test views
# before the first input view
input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included
# in the last forward when input_ords[-1] == test_ords[-1]
input_ords = np.array(input_ords)[:, None]
input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
test_ords = np.array(test_ords)[None]
in_stop_ranges = np.logical_and(
np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
) # (M, N)
assert (in_stop_ranges.sum(1) <= T - 2).all(), (
"More input frames need to be sampled during the first pass to ensure "
f"#test frames during each forard in the second pass will not exceed {T - 2}."
)
if input_ords[1, 0] <= test_ords[0, 0]:
assert not in_stop_ranges[0].any()
if input_ords[-1, 0] >= test_ords[0, -1]:
assert not in_stop_ranges[-1].any()
gt_chunk = (
[f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
)
chunk = gt_chunk + []
# any test views before the first input views
if in_stop_ranges[0].any():
for j, in_range in enumerate(in_stop_ranges[0]):
if in_range:
chunk.append(f">{j:03d}")
in_stop_ranges = in_stop_ranges[1:]
i = 0
base_i = len(gt_input_inds) if "img2trajvid" in task else 0
chunk.append(f"!{i + base_i:03d}")
while i < len(in_stop_ranges):
in_stop_range = in_stop_ranges[i]
if not in_stop_range.any():
i += 1
continue
input_left = i + 1 < M
space_left = T - len(chunk)
if sum(in_stop_range) + input_left <= space_left:
for j, in_range in enumerate(in_stop_range):
if in_range:
chunk.append(f">{j:03d}")
i += 1
if input_left:
chunk.append(f"!{i + base_i:03d}")
else:
chunk += ["NULL"] * space_left
chunks.append(chunk)
chunk = gt_chunk + [f"!{i + base_i:03d}"]
if len(chunk) > 1:
chunk += ["NULL"] * (T - len(chunk))
chunks.append(chunk)
else:
raise NotImplementedError
(
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = (
[],
[],
[],
[],
)
for chunk in chunks:
input_inds = [
int(img.removeprefix("!")) for img in chunk if img.startswith("!")
]
input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
input_inds_per_chunk.append(input_inds)
input_sels_per_chunk.append(input_sels)
test_inds_per_chunk.append(test_inds)
test_sels_per_chunk.append(test_sels)
if options.get("sampler_verbose", True):
def colorize(item):
if item.startswith("!"):
return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!'
elif item.startswith(">"):
return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>'
return item # Default color if neither '!' nor '>'
print("\nchunks:")
for chunk in chunks:
print(", ".join(colorize(item) for item in chunk))
return (
chunks,
input_inds_per_chunk, # ordering of input in raw sequence
input_sels_per_chunk, # ordering of input in one-forward sequence of length T
test_inds_per_chunk, # ordering of test in raw sequence
        test_sels_per_chunk,  # ordering of test in one-forward sequence of length T
)
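# Notation sketch (toy values): chunk entries use "!{idx}" for input/conditioning
# frames, ">{idx}" for test frames to generate, and "NULL" for padding. For example,
# with chunk_strategy="gt", T=5, two gt input views and six test views, the chunks are
#   ["!000", "!001", ">000", ">001", ">002"]
#   ["!000", "!001", ">003", ">004", ">005"]
# and the per-chunk index/selection lists returned above are parsed from these strings.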
def is_k_in_dict(d, k):
return any(map(lambda x: x.startswith(k), d.keys()))
def get_k_from_dict(d, k):
media_d = {}
for key, value in d.items():
if key == k:
return value
if key.startswith(k):
media = key.split("/")[-1]
if media == "raw":
return value
media_d[media] = value
if len(media_d) == 0:
return torch.tensor([])
assert (
len(media_d) == 1
), f"multiple media found in {d} for key {k}: {media_d.keys()}"
return media_d[media]
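# Usage sketch (keys follow the "<name>/<media>" convention used in this module):
#   d = {"samples-rgb/image": rgb_tensor}
#   get_k_from_dict(d, "samples-rgb")  # -> rgb_tensor
# A "<name>/raw" entry is returned as-is, and a missing key yields an empty tensor.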
def update_kv_for_dict(d, k, v):
for key in d.keys():
if key.startswith(k):
d[key] = v
return d
def extend_dict(ds, d):
for key in d.keys():
if key in ds:
ds[key] = torch.cat([ds[key], d[key]], 0)
else:
ds[key] = d[key]
return ds
def replace_or_include_input_for_dict(
samples,
test_indices,
imgs,
c2w,
K,
):
samples_new = {}
for sample, value in samples.items():
if "rgb" in sample:
imgs[test_indices] = (
value[test_indices] if value.shape[0] == imgs.shape[0] else value
).to(device=imgs.device, dtype=imgs.dtype)
samples_new[sample] = imgs
elif "c2w" in sample:
c2w[test_indices] = (
value[test_indices] if value.shape[0] == c2w.shape[0] else value
).to(device=c2w.device, dtype=c2w.dtype)
samples_new[sample] = c2w
elif "intrinsics" in sample:
K[test_indices] = (
value[test_indices] if value.shape[0] == K.shape[0] else value
).to(device=K.device, dtype=K.dtype)
samples_new[sample] = K
else:
samples_new[sample] = value
return samples_new
def decode_output(
samples,
T,
indices=None,
):
    # decode model output into a dict if it is not one already
if isinstance(samples, dict):
# model with postprocessor and outputs dict
for sample, value in samples.items():
if isinstance(value, torch.Tensor):
value = value.detach().cpu()
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
value = torch.tensor(value)
if indices is not None and value.shape[0] == T:
value = value[indices]
samples[sample] = value
else:
# model without postprocessor and outputs tensor (rgb)
samples = samples.detach().cpu()
if indices is not None and samples.shape[0] == T:
samples = samples[indices]
samples = {"samples-rgb/image": samples}
return samples
def save_output(
samples,
save_path,
video_save_fps=2,
):
os.makedirs(save_path, exist_ok=True)
for sample in samples:
media_type = "video"
if "/" in sample:
sample_, media_type = sample.split("/")
else:
sample_ = sample
value = samples[sample]
if isinstance(value, torch.Tensor):
value = value.detach().cpu()
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
value = torch.tensor(value)
if media_type == "image":
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
value = (value * 255).clamp(0, 255).to(torch.uint8)
iio.imwrite(
os.path.join(save_path, f"{sample_}.mp4")
if sample_
else f"{save_path}.mp4",
value,
fps=video_save_fps,
macro_block_size=1,
ffmpeg_log_level="error",
)
os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
for i, s in enumerate(value):
iio.imwrite(
os.path.join(save_path, sample_, f"{i:03d}.png"),
s,
)
elif media_type == "video":
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
value = (value * 255).clamp(0, 255).to(torch.uint8)
iio.imwrite(
os.path.join(save_path, f"{sample_}.mp4"),
value,
fps=video_save_fps,
macro_block_size=1,
ffmpeg_log_level="error",
)
elif media_type == "raw":
torch.save(
value,
os.path.join(save_path, f"{sample_}.pt"),
)
else:
pass
def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
import os.path as osp
out_frames = []
for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
out_frame = {
"fl_x": K[0][0].item(),
"fl_y": K[1][1].item(),
"cx": K[0][2].item(),
"cy": K[1][2].item(),
"w": img_wh[0].item(),
"h": img_wh[1].item(),
"file_path": f"./{osp.relpath(img_path, start=save_path)}"
if img_path is not None
else None,
"transform_matrix": c2w.tolist(),
}
out_frames.append(out_frame)
out = {
# "camera_model": "PINHOLE",
"orientation_override": "none",
"frames": out_frames,
}
with open(osp.join(save_path, "transforms.json"), "w") as of:
json.dump(out, of, indent=5)
class GradioTrackedSampler(EulerEDMSampler):
"""
    A thin wrapper around EulerEDMSampler that allows tracking progress and
    aborting sampling for the Gradio demo.
"""
def __init__(self, abort_event: threading.Event, *args, **kwargs):
super().__init__(*args, **kwargs)
self.abort_event = abort_event
def __call__( # type: ignore
self,
denoiser,
x: torch.Tensor,
scale: float | torch.Tensor,
cond: dict,
uc: dict | None = None,
num_steps: int | None = None,
verbose: bool = True,
global_pbar: gr.Progress | None = None,
**guider_kwargs,
) -> torch.Tensor | None:
uc = cond if uc is None else uc
x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
x,
cond,
uc,
num_steps,
)
for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
gamma = (
min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
if self.s_tmin <= sigmas[i] <= self.s_tmax
else 0.0
)
x = self.sampler_step(
s_in * sigmas[i],
s_in * sigmas[i + 1],
denoiser,
x,
scale,
cond,
uc,
gamma,
**guider_kwargs,
)
# Allow tracking progress in gradio demo.
if global_pbar is not None:
global_pbar.update()
# Allow aborting sampling in gradio demo.
if self.abort_event.is_set():
return None
return x
def create_samplers(
guider_types: int | list[int],
discretization,
num_frames: list[int] | None,
num_steps: int,
cfg_min: float = 1.0,
device: str | torch.device = "cuda",
abort_event: threading.Event | None = None,
):
guider_mapping = {
0: VanillaCFG,
1: MultiviewCFG,
2: MultiviewTemporalCFG,
}
samplers = []
if not isinstance(guider_types, (list, tuple)):
guider_types = [guider_types]
for i, guider_type in enumerate(guider_types):
if guider_type not in guider_mapping:
raise ValueError(
f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
)
guider_cls = guider_mapping[guider_type]
guider_args = ()
if guider_type > 0:
guider_args += (cfg_min,)
if guider_type == 2:
assert num_frames is not None
guider_args = (num_frames[i], cfg_min)
guider = guider_cls(*guider_args)
if abort_event is not None:
sampler = GradioTrackedSampler(
abort_event,
discretization=discretization,
guider=guider,
num_steps=num_steps,
s_churn=0.0,
s_tmin=0.0,
s_tmax=999.0,
s_noise=1.0,
verbose=True,
device=device,
)
else:
sampler = EulerEDMSampler(
discretization=discretization,
guider=guider,
num_steps=num_steps,
s_churn=0.0,
s_tmin=0.0,
s_tmax=999.0,
s_noise=1.0,
verbose=True,
device=device,
)
samplers.append(sampler)
return samplers
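# Call sketch (hypothetical arguments): for 2-pass sampling, one sampler per pass is
# created, e.g.
#   samplers = create_samplers(
#       guider_types=[2, 2],            # MultiviewTemporalCFG for both passes
#       discretization=discretization,  # taken from the denoiser elsewhere in this file
#       num_frames=[21, 21],
#       num_steps=50,
#       cfg_min=1.2,
#   )
# When `abort_event` is provided, GradioTrackedSampler is used instead of the plain
# EulerEDMSampler so the demo can cancel sampling mid-run.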
def get_value_dict(
curr_imgs,
curr_imgs_clip,
curr_input_frame_indices,
curr_c2ws,
curr_Ks,
curr_input_camera_indices,
all_c2ws,
camera_scale=2.0,
):
assert sorted(curr_input_camera_indices) == sorted(
range(len(curr_input_camera_indices))
)
H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
value_dict = {}
value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["cond_frames_mask"][curr_input_frame_indices] = True
value_dict["cond_aug"] = 0.0
c2w = to_hom_pose(curr_c2ws.float())
w2c = torch.linalg.inv(c2w)
# camera centering
ref_c2ws = all_c2ws
camera_dist_2med = torch.norm(
ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
dim=-1,
)
valid_mask = camera_dist_2med <= torch.clamp(
torch.quantile(camera_dist_2med, 0.97) * 10,
max=1e6,
)
c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
w2c = torch.linalg.inv(c2w)
# camera normalization
camera_dists = c2w[:, :3, 3].clone()
translation_scaling_factor = (
camera_scale
if torch.isclose(
torch.norm(camera_dists[0]),
torch.zeros(1),
atol=1e-5,
).any()
else (camera_scale / torch.norm(camera_dists[0]))
)
w2c[:, :3, 3] *= translation_scaling_factor
c2w[:, :3, 3] *= translation_scaling_factor
value_dict["plucker_coordinate"], _ = get_plucker_coordinates(
extrinsics_src=w2c[0],
extrinsics=w2c,
intrinsics=curr_Ks.float().clone(),
mode="plucker",
rel_zero_translation=True,
target_size=(H // F, W // F),
return_grid_cam=True,
)
value_dict["c2w"] = c2w
value_dict["K"] = curr_Ks
value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["camera_mask"][curr_input_camera_indices] = True
return value_dict
def do_sample(
model,
ae,
conditioner,
denoiser,
sampler,
value_dict,
H,
W,
C,
F,
T,
cfg,
encoding_t=1,
decoding_t=1,
verbose=True,
global_pbar=None,
**_,
):
imgs = value_dict["cond_frames"].to("cuda")
input_masks = value_dict["cond_frames_mask"].to("cuda")
pluckers = value_dict["plucker_coordinate"].to("cuda")
num_samples = [1, T]
with torch.inference_mode(), torch.autocast("cuda"):
load_model(ae)
load_model(conditioner)
latents = torch.nn.functional.pad(
ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
)
c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
uc_crossattn = torch.zeros_like(c_crossattn)
c_replace = latents.new_zeros(T, *latents.shape[1:])
c_replace[input_masks] = latents
uc_replace = torch.zeros_like(c_replace)
c_concat = torch.cat(
[
repeat(
input_masks,
"n -> n 1 h w",
h=pluckers.shape[2],
w=pluckers.shape[3],
),
pluckers,
],
1,
)
uc_concat = torch.cat(
[pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
)
c_dense_vector = pluckers
uc_dense_vector = c_dense_vector
# TODO(hangg): concat and dense are problematic.
c = {
"crossattn": c_crossattn,
"replace": c_replace,
"concat": c_concat,
"dense_vector": c_dense_vector,
}
uc = {
"crossattn": uc_crossattn,
"replace": uc_replace,
"concat": uc_concat,
"dense_vector": uc_dense_vector,
}
unload_model(ae)
unload_model(conditioner)
additional_model_inputs = {"num_frames": T}
additional_sampler_inputs = {
"c2w": value_dict["c2w"].to("cuda"),
"K": value_dict["K"].to("cuda"),
"input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
}
if global_pbar is not None:
additional_sampler_inputs["global_pbar"] = global_pbar
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
load_model(model)
samples_z = sampler(
lambda input, sigma, c: denoiser(
model,
input,
sigma,
c,
**additional_model_inputs,
),
randn,
scale=cfg,
cond=c,
uc=uc,
verbose=verbose,
**additional_sampler_inputs,
)
if samples_z is None:
return
unload_model(model)
load_model(ae)
samples = ae.decode(samples_z, decoding_t)
unload_model(ae)
return samples
def run_one_scene(
task,
version_dict,
model,
ae,
conditioner,
denoiser,
image_cond,
camera_cond,
save_path,
use_traj_prior,
traj_prior_Ks,
traj_prior_c2ws,
seed=23,
gradio=False,
abort_event=None,
first_pass_pbar=None,
second_pass_pbar=None,
):
H, W, T, C, F, options = (
version_dict["H"],
version_dict["W"],
version_dict["T"],
version_dict["C"],
version_dict["f"],
version_dict["options"],
)
if isinstance(image_cond, str):
image_cond = {"img": [image_cond]}
imgs_clip, imgs, img_size = [], [], None
for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
if isinstance(img, str) or img is None:
img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore
img_size = img.shape[-2:]
if options.get("L_short", -1) == -1:
img, K = transform_img_and_K(
img,
(W, H),
K=K[None],
mode=(
options.get("transform_input", "crop")
if i in image_cond["input_indices"]
else options.get("transform_target", "crop")
),
scale=(
1.0
if i in image_cond["input_indices"]
else options.get("transform_scale", 1.0)
),
)
else:
downsample = 3
assert options["L_short"] % F * 2**downsample == 0, (
"Short side of the image should be divisible by "
f"F*2**{downsample}={F * 2**downsample}."
)
img, K = transform_img_and_K(
img,
options["L_short"],
K=K[None],
size_stride=F * 2**downsample,
mode=(
options.get("transform_input", "crop")
if i in image_cond["input_indices"]
else options.get("transform_target", "crop")
),
scale=(
1.0
if i in image_cond["input_indices"]
else options.get("transform_scale", 1.0)
),
)
version_dict["W"] = W = img.shape[-1]
version_dict["H"] = H = img.shape[-2]
K = K[0]
K[0] /= W
K[1] /= H
camera_cond["K"][i] = K
img_clip = img
elif isinstance(img, np.ndarray):
img_size = torch.Size(img.shape[:2])
img = torch.as_tensor(img).permute(2, 0, 1)
img = img.unsqueeze(0)
img = img / 255.0 * 2.0 - 1.0
if not gradio:
img, K = transform_img_and_K(img, (W, H), K=K[None])
assert K is not None
K = K[0]
K[0] /= W
K[1] /= H
camera_cond["K"][i] = K
img_clip = img
else:
            assert (
                False
            ), f"Variable `img` has unsupported type {type(img)}."
imgs_clip.append(img_clip)
imgs.append(img)
imgs_clip = torch.cat(imgs_clip, dim=0)
imgs = torch.cat(imgs, dim=0)
if traj_prior_Ks is not None:
assert img_size is not None
for i, prior_k in enumerate(traj_prior_Ks):
img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore
img, prior_k = transform_img_and_K(
img,
(W, H),
K=prior_k[None],
mode=options.get(
"transform_target", "crop"
), # mode for prior is always same as target
scale=options.get(
"transform_scale", 1.0
), # scale for prior is always same as target
)
prior_k = prior_k[0]
prior_k[0] /= W
prior_k[1] /= H
traj_prior_Ks[i] = prior_k
options["num_frames"] = T
discretization = denoiser.discretization
torch.cuda.empty_cache()
seed_everything(seed)
# Get Data
input_indices = image_cond["input_indices"]
input_imgs = imgs[input_indices]
input_imgs_clip = imgs_clip[input_indices]
input_c2ws = camera_cond["c2w"][input_indices]
input_Ks = camera_cond["K"][input_indices]
test_indices = [i for i in range(len(imgs)) if i not in input_indices]
test_imgs = imgs[test_indices]
test_imgs_clip = imgs_clip[test_indices]
test_c2ws = camera_cond["c2w"][test_indices]
test_Ks = camera_cond["K"][test_indices]
if options.get("save_input", True):
save_output(
{"/image": input_imgs},
save_path=os.path.join(save_path, "input"),
video_save_fps=2,
)
if not use_traj_prior:
chunk_strategy = options.get("chunk_strategy", "gt")
(
_,
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = chunk_input_and_test(
T,
input_c2ws,
test_c2ws,
input_indices,
test_indices,
options=options,
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=list(range(input_c2ws.shape[0])),
)
print(
f"One pass - chunking with `{chunk_strategy}` strategy: total "
f"{len(input_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_test_inds = []
for i, (
chunk_input_inds,
chunk_input_sels,
chunk_test_inds,
chunk_test_sels,
) in tqdm(
enumerate(
zip(
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
)
),
total=len(input_inds_per_chunk),
leave=False,
):
(
curr_input_sels,
curr_test_sels,
curr_input_maps,
curr_test_maps,
) = pad_indices(
chunk_input_sels,
chunk_test_sels,
T=T,
padding_mode=options.get("t_padding_mode", "last"),
)
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_input_inds],
test=y[chunk_test_inds],
input_maps=curr_input_maps,
test_maps=curr_test_maps,
)
for x, y in zip(
[
torch.cat(
[
input_imgs,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat(
[
input_imgs_clip,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
                ],  # procedurally append generated prior views to the input views
[test_imgs, test_imgs_clip, test_c2ws, test_Ks],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_imgs_clip.to("cuda"),
curr_input_sels
+ [
sel
for (ind, sel) in zip(
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
curr_test_sels,
)
if test_indices[ind] in image_cond["input_indices"]
],
curr_c2ws,
curr_Ks,
curr_input_sels
+ [
sel
for (ind, sel) in zip(
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
curr_test_sels,
)
if test_indices[ind] in camera_cond["input_indices"]
],
all_c2ws=camera_cond["c2w"],
)
samplers = create_samplers(
options["guider_types"],
discretization,
[len(curr_imgs)],
options["num_steps"],
options["cfg_min"],
abort_event=abort_event,
)
assert len(samplers) == 1
samples = do_sample(
model,
ae,
conditioner,
denoiser,
samplers[0],
value_dict,
H,
W,
C,
F,
T=len(curr_imgs),
cfg=(
options["cfg"][0]
if isinstance(options["cfg"], (list, tuple))
else options["cfg"]
),
**{k: options[k] for k in options if k not in ["cfg", "T"]},
)
samples = decode_output(
samples, len(curr_imgs), chunk_test_sels
) # decode into dict
if options.get("save_first_pass", False):
save_output(
replace_or_include_input_for_dict(
samples,
chunk_test_sels,
curr_imgs,
curr_c2ws,
curr_Ks,
),
save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
video_save_fps=2,
)
extend_dict(all_samples, samples)
all_test_inds.extend(chunk_test_inds)
else:
assert traj_prior_c2ws is not None, (
"`traj_prior_c2ws` should be set when using 2-pass sampling. One "
"potential reason is that the amount of input frames is larger than "
"T. Set `num_prior_frames` manually to overwrite the infered stats."
)
traj_prior_c2ws = torch.as_tensor(
traj_prior_c2ws,
device=input_c2ws.device,
dtype=input_c2ws.dtype,
)
if traj_prior_Ks is None:
traj_prior_Ks = test_Ks[:1].repeat_interleave(
traj_prior_c2ws.shape[0], dim=0
)
traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
traj_prior_imgs_clip = imgs_clip.new_zeros(
traj_prior_c2ws.shape[0], *imgs_clip.shape[1:]
)
# ---------------------------------- first pass ----------------------------------
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
chunk_strategy_first_pass = options.get(
"chunk_strategy_first_pass", "gt-nearest"
)
(
_,
input_inds_per_chunk,
input_sels_per_chunk,
prior_inds_per_chunk,
prior_sels_per_chunk,
) = chunk_input_and_test(
T_first_pass,
input_c2ws,
traj_prior_c2ws,
input_indices,
image_cond["prior_indices"],
options=options,
task=task,
chunk_strategy=chunk_strategy_first_pass,
gt_input_inds=list(range(input_c2ws.shape[0])),
)
print(
f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
f"{len(input_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_prior_inds = []
for i, (
chunk_input_inds,
chunk_input_sels,
chunk_prior_inds,
chunk_prior_sels,
) in tqdm(
enumerate(
zip(
input_inds_per_chunk,
input_sels_per_chunk,
prior_inds_per_chunk,
prior_sels_per_chunk,
)
),
total=len(input_inds_per_chunk),
leave=False,
):
(
curr_input_sels,
curr_prior_sels,
curr_input_maps,
curr_prior_maps,
) = pad_indices(
chunk_input_sels,
chunk_prior_sels,
T=T_first_pass,
padding_mode=options.get("t_padding_mode", "last"),
)
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_input_inds],
test=y[chunk_prior_inds],
input_maps=curr_input_maps,
test_maps=curr_prior_maps,
)
for x, y in zip(
[
torch.cat(
[
input_imgs,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat(
[
input_imgs_clip,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
                ],  # procedurally append generated prior views to the input views
[
traj_prior_imgs,
traj_prior_imgs_clip,
traj_prior_c2ws,
traj_prior_Ks,
],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_imgs_clip.to("cuda"),
curr_input_sels,
curr_c2ws,
curr_Ks,
list(range(T_first_pass)),
all_c2ws=camera_cond["c2w"], # traj_prior_c2ws,
)
samplers = create_samplers(
options["guider_types"],
discretization,
[T_first_pass, T_second_pass],
options["num_steps"],
options["cfg_min"],
abort_event=abort_event,
)
samples = do_sample(
model,
ae,
conditioner,
denoiser,
(
samplers[1]
if len(samplers) > 1
and options.get("ltr_first_pass", False)
and chunk_strategy_first_pass != "gt"
and i > 0
else samplers[0]
),
value_dict,
H,
W,
C,
F,
cfg=(
options["cfg"][0]
if isinstance(options["cfg"], (list, tuple))
else options["cfg"]
),
T=T_first_pass,
global_pbar=first_pass_pbar,
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
)
if samples is None:
return
samples = decode_output(
samples, T_first_pass, chunk_prior_sels
) # decode into dict
extend_dict(all_samples, samples)
all_prior_inds.extend(chunk_prior_inds)
if options.get("save_first_pass", True):
save_output(
all_samples,
save_path=os.path.join(save_path, "first-pass"),
video_save_fps=5,
)
video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
yield video_path_0
# ---------------------------------- second pass ----------------------------------
prior_indices = image_cond["prior_indices"]
assert (
prior_indices is not None
), "`prior_frame_indices` needs to be set if using 2-pass sampling."
prior_argsort = np.argsort(input_indices + prior_indices).tolist()
prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]
traj_prior_imgs = torch.cat(
[input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
)[prior_argsort]
traj_prior_imgs_clip = torch.cat(
[
input_imgs_clip,
get_k_from_dict(all_samples, "samples-rgb"),
],
dim=0,
)[prior_argsort]
traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]
update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)
chunk_strategy = options.get("chunk_strategy", "nearest")
(
_,
prior_inds_per_chunk,
prior_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = chunk_input_and_test(
T_second_pass,
traj_prior_c2ws,
test_c2ws,
prior_indices,
test_indices,
options=options,
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=gt_input_inds,
)
print(
f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
f"{len(prior_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_test_inds = []
for i, (
chunk_prior_inds,
chunk_prior_sels,
chunk_test_inds,
chunk_test_sels,
) in tqdm(
enumerate(
zip(
prior_inds_per_chunk,
prior_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
)
),
total=len(prior_inds_per_chunk),
leave=False,
):
(
curr_prior_sels,
curr_test_sels,
curr_prior_maps,
curr_test_maps,
) = pad_indices(
chunk_prior_sels,
chunk_test_sels,
T=T_second_pass,
padding_mode="last",
)
curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_prior_inds],
test=y[chunk_test_inds],
input_maps=curr_prior_maps,
test_maps=curr_test_maps,
)
for x, y in zip(
[
traj_prior_imgs,
traj_prior_imgs_clip,
traj_prior_c2ws,
traj_prior_Ks,
],
[test_imgs, test_imgs_clip, test_c2ws, test_Ks],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_imgs_clip.to("cuda"),
curr_prior_sels,
curr_c2ws,
curr_Ks,
list(range(T_second_pass)),
all_c2ws=camera_cond["c2w"], # test_c2ws,
)
samples = do_sample(
model,
ae,
conditioner,
denoiser,
samplers[1] if len(samplers) > 1 else samplers[0],
value_dict,
H,
W,
C,
F,
T=T_second_pass,
cfg=(
options["cfg"][1]
if isinstance(options["cfg"], (list, tuple))
and len(options["cfg"]) > 1
else options["cfg"]
),
global_pbar=second_pass_pbar,
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
)
if samples is None:
return
samples = decode_output(
samples, T_second_pass, chunk_test_sels
) # decode into dict
if options.get("save_second_pass", False):
save_output(
replace_or_include_input_for_dict(
samples,
chunk_test_sels,
curr_imgs,
curr_c2ws,
curr_Ks,
),
save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
video_save_fps=2,
)
extend_dict(all_samples, samples)
all_test_inds.extend(chunk_test_inds)
all_samples = {
key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
}
save_output(
replace_or_include_input_for_dict(
all_samples,
test_indices,
imgs.clone(),
camera_cond["c2w"].clone(),
camera_cond["K"].clone(),
)
if options.get("replace_or_include_input", False)
else all_samples,
save_path=save_path,
video_save_fps=options.get("video_save_fps", 2),
)
video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
yield video_path_1