import collections
import json
import math
import os
import re
import threading
from typing import List, Literal, Optional, Tuple, Union

import gradio as gr
from colorama import Fore, Style, init

init(autoreset=True)

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import repeat
from PIL import Image
from tqdm.auto import tqdm

from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
from seva.sampling import (
    EulerEDMSampler,
    MultiviewCFG,
    MultiviewTemporalCFG,
    VanillaCFG,
)
from seva.utils import seed_everything

try:
    # Check if the version string contains 'dev' or 'nightly'.
    version = torch.__version__
    IS_TORCH_NIGHTLY = "dev" in version
    if IS_TORCH_NIGHTLY:
        torch._dynamo.config.cache_size_limit = 128  # type: ignore[assignment]
        torch._dynamo.config.accumulated_cache_size_limit = 1024  # type: ignore[assignment]
        torch._dynamo.config.force_parameter_static_shapes = False  # type: ignore[assignment]
except Exception:
    IS_TORCH_NIGHTLY = False


def pad_indices(
    input_indices: List[int],
    test_indices: List[int],
    T: int,
    padding_mode: Literal["first", "last", "none"] = "last",
):
    assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
    if padding_mode == "last":
        padded_indices = [
            i for i in range(T) if i not in (input_indices + test_indices)
        ]
    else:
        padded_indices = []
    input_selects = list(range(len(input_indices)))
    test_selects = list(range(len(test_indices)))
    if max(input_indices) > max(test_indices):
        # Pad by repeating the last element from the input indices.
        input_selects += [input_selects[-1]] * len(padded_indices)
        input_indices = input_indices + padded_indices
        sorted_inds = np.argsort(input_indices)
        input_indices = [input_indices[ind] for ind in sorted_inds]
        input_selects = [input_selects[ind] for ind in sorted_inds]
    else:
        # Pad by repeating the last element from the test indices.
        test_selects += [test_selects[-1]] * len(padded_indices)
        test_indices = test_indices + padded_indices
        sorted_inds = np.argsort(test_indices)
        test_indices = [test_indices[ind] for ind in sorted_inds]
        test_selects = [test_selects[ind] for ind in sorted_inds]

    if padding_mode == "last":
        input_maps = np.array([-1] * T)
        test_maps = np.array([-1] * T)
    else:
        input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
        test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
    input_maps[input_indices] = input_selects
    test_maps[test_indices] = test_selects
    return input_indices, test_indices, input_maps, test_maps


def assemble(
    input,
    test,
    input_maps,
    test_maps,
):
    T = len(input_maps)
    assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
    assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
    assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
    assert np.logical_xor(input_maps != -1, test_maps != -1).all()
    return assembled


def get_resizing_factor(
    target_shape: Tuple[int, int],  # H, W
    current_shape: Tuple[int, int],  # H, W
    cover_target: bool = True,
    # If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
) -> float:
    r_bound = target_shape[1] / target_shape[0]
    aspect_r = current_shape[1] / current_shape[0]
    if r_bound >= 1.0:
        if cover_target:
            if aspect_r >= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r < 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r >= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r < 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    else:
        if cover_target:
            if aspect_r <= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r > 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r <= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r > 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    return factor


def get_unique_embedder_keys_from_conditioner(conditioner):
    keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
    keys = [item for sublist in keys for item in sublist]  # Flatten list
    return set(keys)


def get_wh_with_fixed_shortest_side(w, h, size):
    # If size is None or not positive, return the original (w, h).
    if size is None or size <= 0:
        return w, h
    if w < h:
        new_w = size
        new_h = int(size * h / w)
    else:
        new_h = size
        new_w = int(size * w / h)
    return new_w, new_h


def load_img_and_K(
    image_path_or_size: Union[str, torch.Size],
    size: Optional[Union[int, Tuple[int, int]]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: torch.Tensor | None = None,
    size_stride: int = 1,
    center_crop: bool = False,
    image_as_tensor: bool = True,
    context_rgb: np.ndarray | None = None,
    device: str = "cuda",
):
    if isinstance(image_path_or_size, torch.Size):
        image = Image.new("RGBA", image_path_or_size[::-1])
    else:
        image = Image.open(image_path_or_size).convert("RGBA")

    w, h = image.size
    if size is None:
        size = (w, h)

    image = np.array(image).astype(np.float32) / 255
    if image.shape[-1] == 4:
        rgb, alpha = image[:, :, :3], image[:, :, 3:]
        if context_rgb is not None:
            image = rgb * alpha + context_rgb * (1 - alpha)
        else:
            image = rgb * alpha + (1 - alpha)
    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image).to(dtype=torch.float32)
    image = image.unsqueeze(0)

    if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
        W, H = size
    else:
        # => if size is an int, we rescale the image to fit the shortest side to size
        # => if size is None, no rescaling is applied
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
    W, H = (
        math.floor(W / size_stride + 0.5) * size_stride,
        math.floor(H / size_stride + 0.5) * size_stride,
    )

    rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
    resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
    image = torch.nn.functional.interpolate(
        image, resize_size, mode="area", antialias=False
    )
    if scale < 1.0:
        pw = math.ceil((W - resize_size[1]) * 0.5)
        ph = math.ceil((H - resize_size[0]) * 0.5)
        image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if center_crop:
        side = min(H, W)
        ct = max(0, cy_center - side // 2)
        cl = max(0, cx_center - side // 2)
        ct = min(ct, image.shape[-2] - side)
        cl = min(cl, image.shape[-1] - side)
        image = TF.crop(image, top=ct, left=cl, height=side, width=side)
    else:
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)

    if K is not None:
        K = K.clone()
        if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
            K[:2] *= K.new_tensor([rw, rh])[:, None]  # normalized K
        else:
            K[:2] *= K.new_tensor([rw / w, rh / h])[:, None]  # unnormalized K
        K[:2, 2] -= K.new_tensor([cl, ct])

    if image_as_tensor:
        # tensor of shape (1, 3, H, W) with values ranging from (-1, 1)
        image = image.to(device) * 2.0 - 1.0
    else:
        # PIL Image with values ranging from (0, 255)
        image = image.permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).astype(np.uint8))
    return image, K


def transform_img_and_K(
    image: torch.Tensor,
    size: Union[int, Tuple[int, int]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: torch.Tensor | None = None,
    size_stride: int = 1,
    mode: str = "crop",
):
    assert mode in [
        "crop",
        "pad",
        "stretch",
    ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"

    h, w = image.shape[-2:]
    if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
        W, H = size
    else:
        # => if size is an int, we rescale the image to fit the shortest side to size
        # => if size is None, no rescaling is applied
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
    W, H = (
        math.floor(W / size_stride + 0.5) * size_stride,
        math.floor(H / size_stride + 0.5) * size_stride,
    )

    if mode == "stretch":
        rh, rw = H, W
    else:
        rfs = get_resizing_factor(
            (H, W),
            (h, w),
            cover_target=mode != "pad",
        )
        (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]

    rh, rw = int(rh / scale), int(rw / scale)
    image = torch.nn.functional.interpolate(
        image, (rh, rw), mode="area", antialias=False
    )

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if mode != "pad":
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)
        pl, pt = 0, 0
    else:
        pt = max(0, H // 2 - cy_center)
        pl = max(0, W // 2 - cx_center)
        pb = max(0, H - pt - image.shape[-2])
        pr = max(0, W - pl - image.shape[-1])
        image = TF.pad(
            image,
            [pl, pt, pr, pb],
        )
        cl, ct = 0, 0

    if K is not None:
        K = K.clone()
        # K[:, :2, 2] += K.new_tensor([pl, pt])
        if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
            K[:, :2] *= K.new_tensor([rw, rh])[None, :, None]  # normalized K
        else:
            K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None]  # unnormalized K
        K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])

    return image, K


lowvram_mode = False


def set_lowvram_mode(mode):
    global lowvram_mode
    lowvram_mode = mode


def load_model(model, device: str = "cuda"):
    model.to(device)


def unload_model(model):
    global lowvram_mode
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()


def infer_prior_stats(
    T,
    num_input_frames,
    num_total_frames,
    version_dict,
):
    options = version_dict["options"]
    chunk_strategy = options.get("chunk_strategy", "nearest")
    T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
    T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
    # Get traj_prior_c2ws for 2-pass sampling.
    if chunk_strategy.startswith("interp"):
        # Start and end have already taken up two slots.
        # +1 means we need X + 1 prior frames to bound X forwards for all test frames.
        # Tuning up `num_prior_frames_ratio` is helpful when you observe a sudden jump
        # in the generated frames due to insufficient prior frames. This option is
        # effective for complicated trajectories and when the `interp` strategy is
        # used (usually the semi-dense-view regime). Recommended range is
        # [1.0 (default), 1.5].
        if num_input_frames >= options.get("num_input_semi_dense", 9):
            num_prior_frames = (
                math.ceil(
                    num_total_frames
                    / (T_second_pass - 2)
                    * options.get("num_prior_frames_ratio", 1.0)
                )
                + 1
            )

            if num_prior_frames + num_input_frames < T_first_pass:
                num_prior_frames = T_first_pass - num_input_frames

            num_prior_frames = max(
                num_prior_frames,
                options.get("num_prior_frames", 0),
            )
            T_first_pass = num_prior_frames + num_input_frames

            if "gt" in chunk_strategy:
                T_second_pass = T_second_pass + num_input_frames

            # Dynamically update context window length.
            version_dict["T"] = [T_first_pass, T_second_pass]

        else:
            num_prior_frames = (
                math.ceil(
                    num_total_frames
                    / (
                        T_second_pass
                        - 2
                        - (num_input_frames if "gt" in chunk_strategy else 0)
                    )
                    * options.get("num_prior_frames_ratio", 1.0)
                )
                + 1
            )

            if num_prior_frames + num_input_frames < T_first_pass:
                num_prior_frames = T_first_pass - num_input_frames

            num_prior_frames = max(
                num_prior_frames,
                options.get("num_prior_frames", 0),
            )
    else:
        num_prior_frames = max(
            T_first_pass - num_input_frames,
            options.get("num_prior_frames", 0),
        )

        if num_input_frames >= options.get("num_input_semi_dense", 9):
            T_first_pass = num_prior_frames + num_input_frames

            # Dynamically update context window length.
            version_dict["T"] = [T_first_pass, T_second_pass]

    return num_prior_frames


def infer_prior_inds(
    c2ws,
    num_prior_frames,
    input_frame_indices,
    options,
):
    chunk_strategy = options.get("chunk_strategy", "nearest")
    if chunk_strategy.startswith("interp"):
        prior_frame_indices = np.array(
            [i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
        )
        prior_frame_indices = prior_frame_indices[
            np.ceil(
                np.linspace(
                    0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
                )
            ).astype(int)
        ]  # having a ceil here is actually safer for corner cases
    else:
        prior_frame_indices = []
        while len(prior_frame_indices) < num_prior_frames:
            closest_distance = np.abs(
                np.arange(c2ws.shape[0])[None]
                - np.concatenate(
                    [np.array(input_frame_indices), np.array(prior_frame_indices)]
                )[:, None]
            ).min(0)
            prior_frame_indices.append(np.argsort(closest_distance)[-1])
    return np.sort(prior_frame_indices)


def compute_relative_inds(
    source_inds,
    target_inds,
):
    assert len(source_inds) > 2
    # Compute relative indices of target_inds within source_inds.
    relative_inds = []
    for ind in target_inds:
        if ind in source_inds:
            relative_ind = int(np.where(source_inds == ind)[0][0])
        elif ind < source_inds[0]:
            # extrapolate
            relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
        elif ind > source_inds[-1]:
            # extrapolate
            relative_ind = len(source_inds) + (
                (ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
            )
        else:
            # interpolate
            lower_inds = source_inds[source_inds < ind]
            upper_inds = source_inds[source_inds > ind]
            if len(lower_inds) > 0 and len(upper_inds) > 0:
                lower_ind = lower_inds[-1]
                upper_ind = upper_inds[0]
                relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
                relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
                relative_ind = relative_lower_ind + (ind - lower_ind) / (
                    upper_ind - lower_ind
                ) * (relative_upper_ind - relative_lower_ind)
            else:
                # Out of range: append a placeholder and skip the shared append below.
                relative_inds.append(float("nan"))
                continue
        relative_inds.append(relative_ind)
    return relative_inds


def find_nearest_source_inds(
    source_c2ws,
target_c2ws, nearest_num=1, mode="translation", ): dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy() sorted_inds = np.argsort(dists, axis=0).T return sorted_inds[:, :nearest_num] def chunk_input_and_test( T, input_c2ws, test_c2ws, input_ords, # orders test_ords, # orders options, task: str = "img2img", chunk_strategy: str = "gt", gt_input_inds: list = [], ): M, N = input_c2ws.shape[0], test_c2ws.shape[0] chunks = [] if chunk_strategy.startswith("gt"): assert len(gt_input_inds) < T, ( f"Number of gt input frames {len(gt_input_inds)} should be " f"less than {T} when `gt` chunking strategy is used." ) assert ( list(range(M)) == gt_input_inds ), "All input_c2ws should be gt when `gt` chunking strategy is used." # LEGACY CHUNKING STRATEGY # num_test_per_chunk = T - len(gt_input_inds) # test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds] # for i in range(0, test_c2ws.shape[0], num_test_per_chunk): # chunk = ["NULL"] * T # for j, k in enumerate(gt_input_inds): # chunk[k] = f"!{j:03d}" # for j, k in enumerate( # test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]] # ): # chunk[k] = f">{i + j:03d}" # chunks.append(chunk) num_test_seen = 0 while num_test_seen < N: chunk = [f"!{i:03d}" for i in gt_input_inds] if chunk_strategy != "gt" and num_test_seen > 0: pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33) if (N - num_test_seen) >= math.floor( (T - len(gt_input_inds)) * pseudo_num_ratio ): pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio) else: pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen) pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000)) if "ltr" in chunk_strategy: chunk.extend( [ f"!{i + len(gt_input_inds):03d}" for i in range(num_test_seen - pseudo_num, num_test_seen) ] ) elif "nearest" in chunk_strategy: source_inds = np.concatenate( [ find_nearest_source_inds( test_c2ws[:num_test_seen], test_c2ws[num_test_seen:], nearest_num=1, # pseudo_num, mode="rotation", ), find_nearest_source_inds( test_c2ws[:num_test_seen], test_c2ws[num_test_seen:], nearest_num=1, # pseudo_num, mode="translation", ), ], axis=1, ) ####### [HACK ALERT] keep running until pseudo num is stablized ######## temp_pseudo_num = pseudo_num while True: nearest_source_inds = np.concatenate( [ np.sort( [ ind for (ind, _) in collections.Counter( [ item for item in source_inds[ : T - len(gt_input_inds) - temp_pseudo_num ] .flatten() .tolist() if item != ( num_test_seen - 1 ) # exclude the last one here ] ).most_common(pseudo_num - 1) ], ).astype(int), [num_test_seen - 1], # always keep the last one ] ) if len(nearest_source_inds) >= temp_pseudo_num: break # stablized else: temp_pseudo_num = len(nearest_source_inds) pseudo_num = len(nearest_source_inds) ######################################################################## chunk.extend( [f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds] ) else: raise NotImplementedError( f"Chunking strategy {chunk_strategy} for the first pass is not implemented." 
) chunk.extend( [ f">{i:03d}" for i in range( num_test_seen, min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N), ) ] ) else: chunk.extend( [ f">{i:03d}" for i in range( num_test_seen, min(num_test_seen + T - len(gt_input_inds), N), ) ] ) num_test_seen += sum([1 for c in chunk if c.startswith(">")]) if len(chunk) < T: chunk.extend(["NULL"] * (T - len(chunk))) chunks.append(chunk) elif chunk_strategy.startswith("nearest"): input_imgs = np.array([f"!{i:03d}" for i in range(M)]) test_imgs = np.array([f">{i:03d}" for i in range(N)]) match = re.match(r"^nearest-(\d+)$", chunk_strategy) if match: nearest_num = int(match.group(1)) assert ( nearest_num < T ), f"Nearest number of {nearest_num} should be less than {T}." source_inds = find_nearest_source_inds( input_c2ws, test_c2ws, nearest_num=nearest_num, mode="translation", # during the second pass, consider translation only is enough ) for i in range(0, N, T - nearest_num): nearest_source_inds = np.sort( [ ind for (ind, _) in collections.Counter( source_inds[i : i + T - nearest_num].flatten().tolist() ).most_common(nearest_num) ] ) chunk = ( input_imgs[nearest_source_inds].tolist() + test_imgs[i : i + T - nearest_num].tolist() ) chunks.append(chunk + ["NULL"] * (T - len(chunk))) else: # do not always condition on gt cond frames if "gt" not in chunk_strategy: gt_input_inds = [] source_inds = find_nearest_source_inds( input_c2ws, test_c2ws, nearest_num=1, mode="translation", # during the second pass, consider translation only is enough )[:, 0] test_inds_per_input = {} for test_idx, input_idx in enumerate(source_inds): if input_idx not in test_inds_per_input: test_inds_per_input[input_idx] = [] test_inds_per_input[input_idx].append(test_idx) num_test_seen = 0 chunk = input_imgs[gt_input_inds].tolist() candidate_input_inds = sorted(list(test_inds_per_input.keys())) while num_test_seen < N: input_idx = candidate_input_inds[0] test_inds = test_inds_per_input[input_idx] input_is_cond = input_idx in gt_input_inds prefix_inds = [] if input_is_cond else [input_idx] if len(chunk) == T - len(prefix_inds) or not candidate_input_inds: if chunk: chunk += ["NULL"] * (T - len(chunk)) chunks.append(chunk) chunk = input_imgs[gt_input_inds].tolist() if num_test_seen >= N: break continue candidate_chunk = ( input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist() ) space_left = T - len(chunk) if len(candidate_chunk) <= space_left: chunk.extend(candidate_chunk) num_test_seen += len(test_inds) candidate_input_inds.pop(0) else: chunk.extend(candidate_chunk[:space_left]) num_input_idx = 0 if input_is_cond else 1 num_test_seen += space_left - num_input_idx test_inds_per_input[input_idx] = test_inds[ space_left - num_input_idx : ] if len(chunk) == T: chunks.append(chunk) chunk = input_imgs[gt_input_inds].tolist() if chunk and chunk != input_imgs[gt_input_inds].tolist(): chunks.append(chunk + ["NULL"] * (T - len(chunk))) elif chunk_strategy.startswith("interp"): # `interp` chunk requires ordering info assert input_ords is not None and test_ords is not None, ( "When using `interp` chunking strategy, ordering of input " "and test frames should be provided." ) # if chunk_strategy is `interp*`` and task is `img2trajvid*`, we will not # use input views since their order info within target views is unknown if "img2trajvid" in task: assert ( list(range(len(gt_input_inds))) == gt_input_inds ), "`img2trajvid` task should put `gt_input_inds` in start." 
input_c2ws = input_c2ws[ [ind for ind in range(M) if ind not in gt_input_inds] ] input_ords = [ input_ords[ind] for ind in range(M) if ind not in gt_input_inds ] M = input_c2ws.shape[0] input_ords = [0] + input_ords # this is a hack accounting for test views # before the first input view input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included # in the last forward when input_ords[-1] == test_ords[-1] input_ords = np.array(input_ords)[:, None] input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)]) test_ords = np.array(test_ords)[None] in_stop_ranges = np.logical_and( np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0), np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0), ) # (M, N) assert (in_stop_ranges.sum(1) <= T - 2).all(), ( "More input frames need to be sampled during the first pass to ensure " f"#test frames during each forard in the second pass will not exceed {T - 2}." ) if input_ords[1, 0] <= test_ords[0, 0]: assert not in_stop_ranges[0].any() if input_ords[-1, 0] >= test_ords[0, -1]: assert not in_stop_ranges[-1].any() gt_chunk = ( [f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else [] ) chunk = gt_chunk + [] # any test views before the first input views if in_stop_ranges[0].any(): for j, in_range in enumerate(in_stop_ranges[0]): if in_range: chunk.append(f">{j:03d}") in_stop_ranges = in_stop_ranges[1:] i = 0 base_i = len(gt_input_inds) if "img2trajvid" in task else 0 chunk.append(f"!{i + base_i:03d}") while i < len(in_stop_ranges): in_stop_range = in_stop_ranges[i] if not in_stop_range.any(): i += 1 continue input_left = i + 1 < M space_left = T - len(chunk) if sum(in_stop_range) + input_left <= space_left: for j, in_range in enumerate(in_stop_range): if in_range: chunk.append(f">{j:03d}") i += 1 if input_left: chunk.append(f"!{i + base_i:03d}") else: chunk += ["NULL"] * space_left chunks.append(chunk) chunk = gt_chunk + [f"!{i + base_i:03d}"] if len(chunk) > 1: chunk += ["NULL"] * (T - len(chunk)) chunks.append(chunk) else: raise NotImplementedError ( input_inds_per_chunk, input_sels_per_chunk, test_inds_per_chunk, test_sels_per_chunk, ) = ( [], [], [], [], ) for chunk in chunks: input_inds = [ int(img.removeprefix("!")) for img in chunk if img.startswith("!") ] input_sels = [chunk.index(img) for img in chunk if img.startswith("!")] test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")] test_sels = [chunk.index(img) for img in chunk if img.startswith(">")] input_inds_per_chunk.append(input_inds) input_sels_per_chunk.append(input_sels) test_inds_per_chunk.append(test_inds) test_sels_per_chunk.append(test_sels) if options.get("sampler_verbose", True): def colorize(item): if item.startswith("!"): return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!' elif item.startswith(">"): return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>' return item # Default color if neither '!' 
nor '>' print("\nchunks:") for chunk in chunks: print(", ".join(colorize(item) for item in chunk)) return ( chunks, input_inds_per_chunk, # ordering of input in raw sequence input_sels_per_chunk, # ordering of input in one-forward sequence of length T test_inds_per_chunk, # ordering of test in raw sequence test_sels_per_chunk, # oredering of test in one-forward sequence of length T ) def is_k_in_dict(d, k): return any(map(lambda x: x.startswith(k), d.keys())) def get_k_from_dict(d, k): media_d = {} for key, value in d.items(): if key == k: return value if key.startswith(k): media = key.split("/")[-1] if media == "raw": return value media_d[media] = value if len(media_d) == 0: return torch.tensor([]) assert ( len(media_d) == 1 ), f"multiple media found in {d} for key {k}: {media_d.keys()}" return media_d[media] def update_kv_for_dict(d, k, v): for key in d.keys(): if key.startswith(k): d[key] = v return d def extend_dict(ds, d): for key in d.keys(): if key in ds: ds[key] = torch.cat([ds[key], d[key]], 0) else: ds[key] = d[key] return ds def replace_or_include_input_for_dict( samples, test_indices, imgs, c2w, K, ): samples_new = {} for sample, value in samples.items(): if "rgb" in sample: imgs[test_indices] = ( value[test_indices] if value.shape[0] == imgs.shape[0] else value ).to(device=imgs.device, dtype=imgs.dtype) samples_new[sample] = imgs elif "c2w" in sample: c2w[test_indices] = ( value[test_indices] if value.shape[0] == c2w.shape[0] else value ).to(device=c2w.device, dtype=c2w.dtype) samples_new[sample] = c2w elif "intrinsics" in sample: K[test_indices] = ( value[test_indices] if value.shape[0] == K.shape[0] else value ).to(device=K.device, dtype=K.dtype) samples_new[sample] = K else: samples_new[sample] = value return samples_new def decode_output( samples, T, indices=None, ): # decode model output into dict if it is not if isinstance(samples, dict): # model with postprocessor and outputs dict for sample, value in samples.items(): if isinstance(value, torch.Tensor): value = value.detach().cpu() elif isinstance(value, np.ndarray): value = torch.from_numpy(value) else: value = torch.tensor(value) if indices is not None and value.shape[0] == T: value = value[indices] samples[sample] = value else: # model without postprocessor and outputs tensor (rgb) samples = samples.detach().cpu() if indices is not None and samples.shape[0] == T: samples = samples[indices] samples = {"samples-rgb/image": samples} return samples def save_output( samples, save_path, video_save_fps=2, ): os.makedirs(save_path, exist_ok=True) for sample in samples: media_type = "video" if "/" in sample: sample_, media_type = sample.split("/") else: sample_ = sample value = samples[sample] if isinstance(value, torch.Tensor): value = value.detach().cpu() elif isinstance(value, np.ndarray): value = torch.from_numpy(value) else: value = torch.tensor(value) if media_type == "image": value = (value.permute(0, 2, 3, 1) + 1) / 2.0 value = (value * 255).clamp(0, 255).to(torch.uint8) iio.imwrite( os.path.join(save_path, f"{sample_}.mp4") if sample_ else f"{save_path}.mp4", value, fps=video_save_fps, macro_block_size=1, ffmpeg_log_level="error", ) os.makedirs(os.path.join(save_path, sample_), exist_ok=True) for i, s in enumerate(value): iio.imwrite( os.path.join(save_path, sample_, f"{i:03d}.png"), s, ) elif media_type == "video": value = (value.permute(0, 2, 3, 1) + 1) / 2.0 value = (value * 255).clamp(0, 255).to(torch.uint8) iio.imwrite( os.path.join(save_path, f"{sample_}.mp4"), value, fps=video_save_fps, macro_block_size=1, 
ffmpeg_log_level="error", ) elif media_type == "raw": torch.save( value, os.path.join(save_path, f"{sample_}.pt"), ) else: pass def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks): import os.path as osp out_frames = [] for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks): out_frame = { "fl_x": K[0][0].item(), "fl_y": K[1][1].item(), "cx": K[0][2].item(), "cy": K[1][2].item(), "w": img_wh[0].item(), "h": img_wh[1].item(), "file_path": f"./{osp.relpath(img_path, start=save_path)}" if img_path is not None else None, "transform_matrix": c2w.tolist(), } out_frames.append(out_frame) out = { # "camera_model": "PINHOLE", "orientation_override": "none", "frames": out_frames, } with open(osp.join(save_path, "transforms.json"), "w") as of: json.dump(out, of, indent=5) class GradioTrackedSampler(EulerEDMSampler): """ A thin wrapper around the EulerEDMSampler that allows tracking progress and aborting sampling for gradio demo. """ def __init__(self, abort_event: threading.Event, *args, **kwargs): super().__init__(*args, **kwargs) self.abort_event = abort_event def __call__( # type: ignore self, denoiser, x: torch.Tensor, scale: float | torch.Tensor, cond: dict, uc: dict | None = None, num_steps: int | None = None, verbose: bool = True, global_pbar: gr.Progress | None = None, **guider_kwargs, ) -> torch.Tensor | None: uc = cond if uc is None else uc x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop( x, cond, uc, num_steps, ) for i in self.get_sigma_gen(num_sigmas, verbose=verbose): gamma = ( min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1) if self.s_tmin <= sigmas[i] <= self.s_tmax else 0.0 ) x = self.sampler_step( s_in * sigmas[i], s_in * sigmas[i + 1], denoiser, x, scale, cond, uc, gamma, **guider_kwargs, ) # Allow tracking progress in gradio demo. if global_pbar is not None: global_pbar.update() # Allow aborting sampling in gradio demo. if self.abort_event.is_set(): return None return x def create_samplers( guider_types: int | list[int], discretization, num_frames: list[int] | None, num_steps: int, cfg_min: float = 1.0, device: str | torch.device = "cuda", abort_event: threading.Event | None = None, ): guider_mapping = { 0: VanillaCFG, 1: MultiviewCFG, 2: MultiviewTemporalCFG, } samplers = [] if not isinstance(guider_types, (list, tuple)): guider_types = [guider_types] for i, guider_type in enumerate(guider_types): if guider_type not in guider_mapping: raise ValueError( f"Invalid guider type {guider_type}. 
Must be one of {list(guider_mapping.keys())}" ) guider_cls = guider_mapping[guider_type] guider_args = () if guider_type > 0: guider_args += (cfg_min,) if guider_type == 2: assert num_frames is not None guider_args = (num_frames[i], cfg_min) guider = guider_cls(*guider_args) if abort_event is not None: sampler = GradioTrackedSampler( abort_event, discretization=discretization, guider=guider, num_steps=num_steps, s_churn=0.0, s_tmin=0.0, s_tmax=999.0, s_noise=1.0, verbose=True, device=device, ) else: sampler = EulerEDMSampler( discretization=discretization, guider=guider, num_steps=num_steps, s_churn=0.0, s_tmin=0.0, s_tmax=999.0, s_noise=1.0, verbose=True, device=device, ) samplers.append(sampler) return samplers def get_value_dict( curr_imgs, curr_imgs_clip, curr_input_frame_indices, curr_c2ws, curr_Ks, curr_input_camera_indices, all_c2ws, camera_scale=2.0, ): assert sorted(curr_input_camera_indices) == sorted( range(len(curr_input_camera_indices)) ) H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8 value_dict = {} value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices] value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs) value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool) value_dict["cond_frames_mask"][curr_input_frame_indices] = True value_dict["cond_aug"] = 0.0 c2w = to_hom_pose(curr_c2ws.float()) w2c = torch.linalg.inv(c2w) # camera centering ref_c2ws = all_c2ws camera_dist_2med = torch.norm( ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values, dim=-1, ) valid_mask = camera_dist_2med <= torch.clamp( torch.quantile(camera_dist_2med, 0.97) * 10, max=1e6, ) c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True) w2c = torch.linalg.inv(c2w) # camera normalization camera_dists = c2w[:, :3, 3].clone() translation_scaling_factor = ( camera_scale if torch.isclose( torch.norm(camera_dists[0]), torch.zeros(1), atol=1e-5, ).any() else (camera_scale / torch.norm(camera_dists[0])) ) w2c[:, :3, 3] *= translation_scaling_factor c2w[:, :3, 3] *= translation_scaling_factor value_dict["plucker_coordinate"], _ = get_plucker_coordinates( extrinsics_src=w2c[0], extrinsics=w2c, intrinsics=curr_Ks.float().clone(), mode="plucker", rel_zero_translation=True, target_size=(H // F, W // F), return_grid_cam=True, ) value_dict["c2w"] = c2w value_dict["K"] = curr_Ks value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool) value_dict["camera_mask"][curr_input_camera_indices] = True return value_dict def do_sample( model, ae, conditioner, denoiser, sampler, value_dict, H, W, C, F, T, cfg, encoding_t=1, decoding_t=1, verbose=True, global_pbar=None, **_, ): imgs = value_dict["cond_frames"].to("cuda") input_masks = value_dict["cond_frames_mask"].to("cuda") pluckers = value_dict["plucker_coordinate"].to("cuda") num_samples = [1, T] with torch.inference_mode(), torch.autocast("cuda"): load_model(ae) load_model(conditioner) latents = torch.nn.functional.pad( ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0 ) c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T) uc_crossattn = torch.zeros_like(c_crossattn) c_replace = latents.new_zeros(T, *latents.shape[1:]) c_replace[input_masks] = latents uc_replace = torch.zeros_like(c_replace) c_concat = torch.cat( [ repeat( input_masks, "n -> n 1 h w", h=pluckers.shape[2], w=pluckers.shape[3], ), pluckers, ], 1, ) uc_concat = torch.cat( [pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1 ) c_dense_vector = 
pluckers uc_dense_vector = c_dense_vector # TODO(hangg): concat and dense are problematic. c = { "crossattn": c_crossattn, "replace": c_replace, "concat": c_concat, "dense_vector": c_dense_vector, } uc = { "crossattn": uc_crossattn, "replace": uc_replace, "concat": uc_concat, "dense_vector": uc_dense_vector, } unload_model(ae) unload_model(conditioner) additional_model_inputs = {"num_frames": T} additional_sampler_inputs = { "c2w": value_dict["c2w"].to("cuda"), "K": value_dict["K"].to("cuda"), "input_frame_mask": value_dict["cond_frames_mask"].to("cuda"), } if global_pbar is not None: additional_sampler_inputs["global_pbar"] = global_pbar shape = (math.prod(num_samples), C, H // F, W // F) randn = torch.randn(shape).to("cuda") load_model(model) samples_z = sampler( lambda input, sigma, c: denoiser( model, input, sigma, c, **additional_model_inputs, ), randn, scale=cfg, cond=c, uc=uc, verbose=verbose, **additional_sampler_inputs, ) if samples_z is None: return unload_model(model) load_model(ae) samples = ae.decode(samples_z, decoding_t) unload_model(ae) return samples def run_one_scene( task, version_dict, model, ae, conditioner, denoiser, image_cond, camera_cond, save_path, use_traj_prior, traj_prior_Ks, traj_prior_c2ws, seed=23, gradio=False, abort_event=None, first_pass_pbar=None, second_pass_pbar=None, ): H, W, T, C, F, options = ( version_dict["H"], version_dict["W"], version_dict["T"], version_dict["C"], version_dict["f"], version_dict["options"], ) if isinstance(image_cond, str): image_cond = {"img": [image_cond]} imgs_clip, imgs, img_size = [], [], None for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])): if isinstance(img, str) or img is None: img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore img_size = img.shape[-2:] if options.get("L_short", -1) == -1: img, K = transform_img_and_K( img, (W, H), K=K[None], mode=( options.get("transform_input", "crop") if i in image_cond["input_indices"] else options.get("transform_target", "crop") ), scale=( 1.0 if i in image_cond["input_indices"] else options.get("transform_scale", 1.0) ), ) else: downsample = 3 assert options["L_short"] % F * 2**downsample == 0, ( "Short side of the image should be divisible by " f"F*2**{downsample}={F * 2**downsample}." ) img, K = transform_img_and_K( img, options["L_short"], K=K[None], size_stride=F * 2**downsample, mode=( options.get("transform_input", "crop") if i in image_cond["input_indices"] else options.get("transform_target", "crop") ), scale=( 1.0 if i in image_cond["input_indices"] else options.get("transform_scale", 1.0) ), ) version_dict["W"] = W = img.shape[-1] version_dict["H"] = H = img.shape[-2] K = K[0] K[0] /= W K[1] /= H camera_cond["K"][i] = K img_clip = img elif isinstance(img, np.ndarray): img_size = torch.Size(img.shape[:2]) img = torch.as_tensor(img).permute(2, 0, 1) img = img.unsqueeze(0) img = img / 255.0 * 2.0 - 1.0 if not gradio: img, K = transform_img_and_K(img, (W, H), K=K[None]) assert K is not None K = K[0] K[0] /= W K[1] /= H camera_cond["K"][i] = K img_clip = img else: assert ( False ), f"Variable `img` got {type(img)} type which is not supported!!!" 
imgs_clip.append(img_clip) imgs.append(img) imgs_clip = torch.cat(imgs_clip, dim=0) imgs = torch.cat(imgs, dim=0) if traj_prior_Ks is not None: assert img_size is not None for i, prior_k in enumerate(traj_prior_Ks): img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore img, prior_k = transform_img_and_K( img, (W, H), K=prior_k[None], mode=options.get( "transform_target", "crop" ), # mode for prior is always same as target scale=options.get( "transform_scale", 1.0 ), # scale for prior is always same as target ) prior_k = prior_k[0] prior_k[0] /= W prior_k[1] /= H traj_prior_Ks[i] = prior_k options["num_frames"] = T discretization = denoiser.discretization torch.cuda.empty_cache() seed_everything(seed) # Get Data input_indices = image_cond["input_indices"] input_imgs = imgs[input_indices] input_imgs_clip = imgs_clip[input_indices] input_c2ws = camera_cond["c2w"][input_indices] input_Ks = camera_cond["K"][input_indices] test_indices = [i for i in range(len(imgs)) if i not in input_indices] test_imgs = imgs[test_indices] test_imgs_clip = imgs_clip[test_indices] test_c2ws = camera_cond["c2w"][test_indices] test_Ks = camera_cond["K"][test_indices] if options.get("save_input", True): save_output( {"/image": input_imgs}, save_path=os.path.join(save_path, "input"), video_save_fps=2, ) if not use_traj_prior: chunk_strategy = options.get("chunk_strategy", "gt") ( _, input_inds_per_chunk, input_sels_per_chunk, test_inds_per_chunk, test_sels_per_chunk, ) = chunk_input_and_test( T, input_c2ws, test_c2ws, input_indices, test_indices, options=options, task=task, chunk_strategy=chunk_strategy, gt_input_inds=list(range(input_c2ws.shape[0])), ) print( f"One pass - chunking with `{chunk_strategy}` strategy: total " f"{len(input_inds_per_chunk)} forward(s) ..." 
) all_samples = {} all_test_inds = [] for i, ( chunk_input_inds, chunk_input_sels, chunk_test_inds, chunk_test_sels, ) in tqdm( enumerate( zip( input_inds_per_chunk, input_sels_per_chunk, test_inds_per_chunk, test_sels_per_chunk, ) ), total=len(input_inds_per_chunk), leave=False, ): ( curr_input_sels, curr_test_sels, curr_input_maps, curr_test_maps, ) = pad_indices( chunk_input_sels, chunk_test_sels, T=T, padding_mode=options.get("t_padding_mode", "last"), ) curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ assemble( input=x[chunk_input_inds], test=y[chunk_test_inds], input_maps=curr_input_maps, test_maps=curr_test_maps, ) for x, y in zip( [ torch.cat( [ input_imgs, get_k_from_dict(all_samples, "samples-rgb").to( input_imgs.device ), ], dim=0, ), torch.cat( [ input_imgs_clip, get_k_from_dict(all_samples, "samples-rgb").to( input_imgs.device ), ], dim=0, ), torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0), torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0), ], # procedually append generated prior views to the input views [test_imgs, test_imgs_clip, test_c2ws, test_Ks], ) ] value_dict = get_value_dict( curr_imgs.to("cuda"), curr_imgs_clip.to("cuda"), curr_input_sels + [ sel for (ind, sel) in zip( np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]], curr_test_sels, ) if test_indices[ind] in image_cond["input_indices"] ], curr_c2ws, curr_Ks, curr_input_sels + [ sel for (ind, sel) in zip( np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]], curr_test_sels, ) if test_indices[ind] in camera_cond["input_indices"] ], all_c2ws=camera_cond["c2w"], ) samplers = create_samplers( options["guider_types"], discretization, [len(curr_imgs)], options["num_steps"], options["cfg_min"], abort_event=abort_event, ) assert len(samplers) == 1 samples = do_sample( model, ae, conditioner, denoiser, samplers[0], value_dict, H, W, C, F, T=len(curr_imgs), cfg=( options["cfg"][0] if isinstance(options["cfg"], (list, tuple)) else options["cfg"] ), **{k: options[k] for k in options if k not in ["cfg", "T"]}, ) samples = decode_output( samples, len(curr_imgs), chunk_test_sels ) # decode into dict if options.get("save_first_pass", False): save_output( replace_or_include_input_for_dict( samples, chunk_test_sels, curr_imgs, curr_c2ws, curr_Ks, ), save_path=os.path.join(save_path, "first-pass", f"forward_{i}"), video_save_fps=2, ) extend_dict(all_samples, samples) all_test_inds.extend(chunk_test_inds) else: assert traj_prior_c2ws is not None, ( "`traj_prior_c2ws` should be set when using 2-pass sampling. One " "potential reason is that the amount of input frames is larger than " "T. Set `num_prior_frames` manually to overwrite the infered stats." 
) traj_prior_c2ws = torch.as_tensor( traj_prior_c2ws, device=input_c2ws.device, dtype=input_c2ws.dtype, ) if traj_prior_Ks is None: traj_prior_Ks = test_Ks[:1].repeat_interleave( traj_prior_c2ws.shape[0], dim=0 ) traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:]) traj_prior_imgs_clip = imgs_clip.new_zeros( traj_prior_c2ws.shape[0], *imgs_clip.shape[1:] ) # ---------------------------------- first pass ---------------------------------- T_first_pass = T[0] if isinstance(T, (list, tuple)) else T T_second_pass = T[1] if isinstance(T, (list, tuple)) else T chunk_strategy_first_pass = options.get( "chunk_strategy_first_pass", "gt-nearest" ) ( _, input_inds_per_chunk, input_sels_per_chunk, prior_inds_per_chunk, prior_sels_per_chunk, ) = chunk_input_and_test( T_first_pass, input_c2ws, traj_prior_c2ws, input_indices, image_cond["prior_indices"], options=options, task=task, chunk_strategy=chunk_strategy_first_pass, gt_input_inds=list(range(input_c2ws.shape[0])), ) print( f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total " f"{len(input_inds_per_chunk)} forward(s) ..." ) all_samples = {} all_prior_inds = [] for i, ( chunk_input_inds, chunk_input_sels, chunk_prior_inds, chunk_prior_sels, ) in tqdm( enumerate( zip( input_inds_per_chunk, input_sels_per_chunk, prior_inds_per_chunk, prior_sels_per_chunk, ) ), total=len(input_inds_per_chunk), leave=False, ): ( curr_input_sels, curr_prior_sels, curr_input_maps, curr_prior_maps, ) = pad_indices( chunk_input_sels, chunk_prior_sels, T=T_first_pass, padding_mode=options.get("t_padding_mode", "last"), ) curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ assemble( input=x[chunk_input_inds], test=y[chunk_prior_inds], input_maps=curr_input_maps, test_maps=curr_prior_maps, ) for x, y in zip( [ torch.cat( [ input_imgs, get_k_from_dict(all_samples, "samples-rgb").to( input_imgs.device ), ], dim=0, ), torch.cat( [ input_imgs_clip, get_k_from_dict(all_samples, "samples-rgb").to( input_imgs.device ), ], dim=0, ), torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0), torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0), ], # procedually append generated prior views to the input views [ traj_prior_imgs, traj_prior_imgs_clip, traj_prior_c2ws, traj_prior_Ks, ], ) ] value_dict = get_value_dict( curr_imgs.to("cuda"), curr_imgs_clip.to("cuda"), curr_input_sels, curr_c2ws, curr_Ks, list(range(T_first_pass)), all_c2ws=camera_cond["c2w"], # traj_prior_c2ws, ) samplers = create_samplers( options["guider_types"], discretization, [T_first_pass, T_second_pass], options["num_steps"], options["cfg_min"], abort_event=abort_event, ) samples = do_sample( model, ae, conditioner, denoiser, ( samplers[1] if len(samplers) > 1 and options.get("ltr_first_pass", False) and chunk_strategy_first_pass != "gt" and i > 0 else samplers[0] ), value_dict, H, W, C, F, cfg=( options["cfg"][0] if isinstance(options["cfg"], (list, tuple)) else options["cfg"] ), T=T_first_pass, global_pbar=first_pass_pbar, **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]}, ) if samples is None: return samples = decode_output( samples, T_first_pass, chunk_prior_sels ) # decode into dict extend_dict(all_samples, samples) all_prior_inds.extend(chunk_prior_inds) if options.get("save_first_pass", True): save_output( all_samples, save_path=os.path.join(save_path, "first-pass"), video_save_fps=5, ) video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4") yield video_path_0 # ---------------------------------- second pass 
---------------------------------- prior_indices = image_cond["prior_indices"] assert ( prior_indices is not None ), "`prior_frame_indices` needs to be set if using 2-pass sampling." prior_argsort = np.argsort(input_indices + prior_indices).tolist() prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist() gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])] traj_prior_imgs = torch.cat( [input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0 )[prior_argsort] traj_prior_imgs_clip = torch.cat( [ input_imgs_clip, get_k_from_dict(all_samples, "samples-rgb"), ], dim=0, )[prior_argsort] traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort] traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort] update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs) update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws) update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks) chunk_strategy = options.get("chunk_strategy", "nearest") ( _, prior_inds_per_chunk, prior_sels_per_chunk, test_inds_per_chunk, test_sels_per_chunk, ) = chunk_input_and_test( T_second_pass, traj_prior_c2ws, test_c2ws, prior_indices, test_indices, options=options, task=task, chunk_strategy=chunk_strategy, gt_input_inds=gt_input_inds, ) print( f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total " f"{len(prior_inds_per_chunk)} forward(s) ..." ) all_samples = {} all_test_inds = [] for i, ( chunk_prior_inds, chunk_prior_sels, chunk_test_inds, chunk_test_sels, ) in tqdm( enumerate( zip( prior_inds_per_chunk, prior_sels_per_chunk, test_inds_per_chunk, test_sels_per_chunk, ) ), total=len(prior_inds_per_chunk), leave=False, ): ( curr_prior_sels, curr_test_sels, curr_prior_maps, curr_test_maps, ) = pad_indices( chunk_prior_sels, chunk_test_sels, T=T_second_pass, padding_mode="last", ) curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [ assemble( input=x[chunk_prior_inds], test=y[chunk_test_inds], input_maps=curr_prior_maps, test_maps=curr_test_maps, ) for x, y in zip( [ traj_prior_imgs, traj_prior_imgs_clip, traj_prior_c2ws, traj_prior_Ks, ], [test_imgs, test_imgs_clip, test_c2ws, test_Ks], ) ] value_dict = get_value_dict( curr_imgs.to("cuda"), curr_imgs_clip.to("cuda"), curr_prior_sels, curr_c2ws, curr_Ks, list(range(T_second_pass)), all_c2ws=camera_cond["c2w"], # test_c2ws, ) samples = do_sample( model, ae, conditioner, denoiser, samplers[1] if len(samplers) > 1 else samplers[0], value_dict, H, W, C, F, T=T_second_pass, cfg=( options["cfg"][1] if isinstance(options["cfg"], (list, tuple)) and len(options["cfg"]) > 1 else options["cfg"] ), global_pbar=second_pass_pbar, **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]}, ) if samples is None: return samples = decode_output( samples, T_second_pass, chunk_test_sels ) # decode into dict if options.get("save_second_pass", False): save_output( replace_or_include_input_for_dict( samples, chunk_test_sels, curr_imgs, curr_c2ws, curr_Ks, ), save_path=os.path.join(save_path, "second-pass", f"forward_{i}"), video_save_fps=2, ) extend_dict(all_samples, samples) all_test_inds.extend(chunk_test_inds) all_samples = { key: value[np.argsort(all_test_inds)] for key, value in all_samples.items() } save_output( replace_or_include_input_for_dict( all_samples, test_indices, imgs.clone(), camera_cond["c2w"].clone(), camera_cond["K"].clone(), ) if options.get("replace_or_include_input", False) else all_samples, save_path=save_path, 
        video_save_fps=options.get("video_save_fps", 2),
    )

    video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
    yield video_path_1
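

# ---------------------------------------------------------------------------
# Illustrative usage sketch (added for clarity; not part of the original
# module). It assumes only the helpers defined above
# (`get_wh_with_fixed_shortest_side`, `pad_indices`, `assemble`) and shows how
# input and test frames are interleaved into a single context window of
# length T, which is the building block used by `chunk_input_and_test` and
# `run_one_scene`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 1) Resize bookkeeping: fit the shortest side of a 1280x720 image to 576.
    new_w, new_h = get_wh_with_fixed_shortest_side(1280, 720, 576)
    print(new_w, new_h)  # 1024 576

    # 2) Window assembly: place 2 input frames at slots [0, 4] and 3 test
    #    frames at slots [1, 2, 3] of a T=8 window; the remaining slots are
    #    padded by repeating the last frame (here the last input frame, since
    #    its slot index is the largest).
    T = 8
    input_sels, test_sels, input_maps, test_maps = pad_indices(
        [0, 4], [1, 2, 3], T=T, padding_mode="last"
    )
    input_frames = torch.zeros(2, 3, 64, 64)  # dummy "input" frames
    test_frames = torch.ones(3, 3, 64, 64)  # dummy "test" frames
    window = assemble(input_frames, test_frames, input_maps, test_maps)
    print(window.shape)  # torch.Size([8, 3, 64, 64])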