import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm

from seva.geometry import get_camera_dist


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less"
        )
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x: torch.Tensor) -> torch.Tensor:
    """Appends a single zero (the final noise level) to a 1D tensor of sigmas."""
    return torch.cat([x, x.new_zeros([1])])


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    """Converts a denoiser output to the Karras ODE derivative d = (x - denoised) / sigma."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    """Builds the scaled-linear DDPM beta schedule (linear in sqrt(beta))."""
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()


def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    """Picks num_substeps roughly equally spaced timestep indices in [0, max_step)."""
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]


class EpsScaling(object):
    """Karras-style scaling coefficients for an eps-prediction network."""

    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma, device=sigma.device)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise


class DDPMDiscretization(object):
    """Maps a DDPM alphas_cumprod schedule onto continuous noise levels (sigmas)."""

    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: float | None = 2.4,
    ):
        self.num_timesteps = num_timesteps
        betas = make_betas(
            num_timesteps, linear_start=linear_start, linear_end=linear_end
        )
        self.log_snr_shift = log_snr_shift
        alphas = 1.0 - betas  # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")

        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            # Shift the log-SNR by a constant, i.e. scale every sigma by exp(shift).
            sigmas = sigmas * np.exp(self.log_snr_shift)
        # Return sigmas in descending order (largest noise level first).
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: str | torch.device = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))


class DiscreteDenoiser(object):
    """Wraps an eps-prediction network as a denoiser over a discrete sigma table."""

    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: str | torch.device = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        # Nearest-neighbor lookup of each sigma in the registered table.
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        # Snap sigma to the nearest discrete level before computing scalings.
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
        if "replace" in cond:
            # Overwrite masked regions of the input with the provided latents.
            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
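# Illustrative sketch, not part of the original module: how the discretization
# and denoiser compose. `_DummyEpsNet` and `_demo_discrete_denoiser` are
# hypothetical names introduced here; the real network in seva receives
# (scaled noisy input, discrete timestep indices, conditioning dict) and
# predicts eps. Shapes below are arbitrary.
class _DummyEpsNet(nn.Module):
    def forward(self, x: torch.Tensor, t: torch.Tensor, cond: dict, **kwargs):
        return torch.zeros_like(x)  # a zero eps-prediction, for shape checking


def _demo_discrete_denoiser() -> None:
    discretization = DDPMDiscretization()
    denoiser = DiscreteDenoiser(discretization, num_idx=1000, device="cpu")
    x = torch.randn(2, 4, 8, 8)  # e.g. a batch of 4-channel latents
    sigma = denoiser.idx_to_sigma(torch.tensor([999, 999]))  # one sigma per sample
    out = denoiser(_DummyEpsNet(), x, sigma, cond={})
    assert out.shape == x.shape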
class ConstantScaleRule(object):
    """Returns the guidance scale unchanged."""

    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
        return scale


class MultiviewScaleRule(object):
    """Drops the guidance scale to min_scale for frames close to an input view."""

    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        # A frame counts as "close" if it nearly matches some input camera in
        # rotation (< 10 degrees), translation, and intrinsics.
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale


class ConstantScaleSchedule(object):
    """Keeps the guidance scale constant across noise levels."""

    def __call__(
        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> float | torch.Tensor:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")


class ConstantGuidance(object):
    """Standard classifier-free guidance: uncond + scale * (cond - uncond)."""

    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: float | torch.Tensor,
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)


class VanillaCFG(object):
    """Vanilla classifier-free guidance over an (uncond, cond) doubled batch."""

    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        # Stack unconditional and conditional inputs into one doubled batch.
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out


class MultiviewCFG(VanillaCFG):
    """CFG whose scale is reduced for frames near the input views."""

    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred
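# Illustrative sketch, not part of the original module: one vanilla CFG
# round-trip. `_demo_vanilla_cfg` is a hypothetical name; shapes and the
# scale value are arbitrary. `prepare_inputs` doubles the batch into
# (uncond, cond) halves; after the network runs on both, __call__ blends
# them back as uncond + scale * (cond - uncond).
def _demo_vanilla_cfg() -> None:
    guider = VanillaCFG()
    x = torch.randn(2, 4, 8, 8)
    s = torch.ones(2)
    c = {"crossattn": torch.randn(2, 77, 10)}
    uc = {"crossattn": torch.zeros_like(c["crossattn"])}
    x2, s2, c2 = guider.prepare_inputs(x, s, c, uc)  # batch doubles: 2 -> 4
    denoised = x2  # stand-in for a denoiser output on the doubled batch
    x_pred = guider(denoised, sigma=1.0, scale=2.0)  # back to batch size 2
    assert x_pred.shape == x.shape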
class MultiviewTemporalCFG(MultiviewCFG):
    """MultiviewCFG with a per-frame scale that grows with temporal distance
    from the nearest input frame."""

    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        # |i - j| distance between every pair of frame indices.
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
        # Distance from each frame to its nearest input frame, normalized to
        # [0, 1]; non-input frames are pushed away by num_frames.
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        # Interpolate between scale_min (at input frames) and the full scale.
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(
            x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1)
        )


class EulerEDMSampler(object):
    """Euler-step EDM sampler (Karras et al.) with optional churn noise."""

    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
        num_steps: int | None = None,
        verbose: bool = False,
        device: str | torch.device = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        # Scale the initial noise to the first (largest) noise level.
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        # Churn: temporarily raise the noise level, then take one Euler step.
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5

        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x, cond, uc, num_steps
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
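# Illustrative end-to-end sketch, not part of the original module: assemble
# discretization, denoiser, guider, and sampler, and draw one sample with the
# hypothetical `_DummyEpsNet` defined above. `_demo_sampling` is likewise a
# hypothetical name; with VanillaCFG the conditioning is doubled inside
# `sampler_step`, so the dummy network simply sees a doubled batch.
def _demo_sampling() -> None:
    discretization = DDPMDiscretization()
    denoiser_core = DiscreteDenoiser(discretization, num_idx=1000, device="cpu")
    sampler = EulerEDMSampler(
        discretization, guider=VanillaCFG(), num_steps=5, device="cpu"
    )
    network = _DummyEpsNet()

    # Close over the network so the sampler sees a denoiser(x, sigma, cond).
    def denoiser(x, sigma, cond):
        return denoiser_core(network, x, sigma, cond)

    x = torch.randn(2, 4, 8, 8)  # start from pure noise
    cond = {"crossattn": torch.randn(2, 77, 10)}
    sample = sampler(denoiser, x, scale=2.0, cond=cond)
    assert sample.shape == (2, 4, 8, 8)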