import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm

from seva.geometry import get_camera_dist


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends singleton dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, "
            "which is less than the input's number of dims"
        )
    return x[(...,) + (None,) * dims_to_append]


def append_zero(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x.new_zeros([1])])


def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    return (x - denoised) / append_dims(sigma, x.ndim)
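
# `to_d` converts a denoised prediction into the derivative of the EDM probability
# flow ODE, d = (x - D(x; sigma)) / sigma, which the Euler sampler below integrates.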


def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()


def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
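
# For example, generate_roughly_equally_spaced_steps(4, 1000) returns
# array([249, 499, 749, 999]): ascending indices that roughly evenly cover the
# 1000-step DDPM schedule, always ending at the last step.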


class EpsScaling(object):
    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
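
# These are the EDM-style preconditioning coefficients for an eps-predicting
# network: c_skip = 1, c_out = -sigma, c_in = 1 / sqrt(sigma^2 + 1), so that the
# denoiser below computes D(x; sigma) = x - sigma * eps(c_in * x, t). c_noise is
# later quantized to a discrete timestep index by DiscreteDenoiser.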


class DDPMDiscretization(object):
    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: float | None = 2.4,
    ):
        self.num_timesteps = num_timesteps
        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.log_snr_shift = log_snr_shift
        alphas = 1.0 - betas  # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            sigmas = sigmas * np.exp(self.log_snr_shift)
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: str | torch.device = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))
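
# Minimal usage sketch (illustrative only): calling the discretization returns
# noise levels from high to low, with a trailing zero for the final denoised
# state.
#
#   discretization = DDPMDiscretization()
#   sigmas = discretization(10)  # shape (11,): 10 descending sigmas, then 0.0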


class DiscreteDenoiser(object):
    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: str | torch.device = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
        if "replace" in cond:
            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
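
# Usage sketch (with a hypothetical `net`, not part of this module): the denoiser
# quantizes a continuous sigma to the nearest of its num_idx discrete levels,
# passes the matching timestep index to the network, and applies the EpsScaling
# preconditioning. An optional cond["replace"] tensor, split into (latent, mask),
# overwrites masked regions of the input before denoising.
#
#   denoiser = DiscreteDenoiser(DDPMDiscretization())
#   denoised = denoiser(net, x_noisy, sigma, cond={})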


class ConstantScaleRule(object):
    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
        return scale


class MultiviewScaleRule(object):
    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale
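
# A target view counts as close to some input view when its rotation differs by
# less than 10 (in get_camera_dist's rotation units, presumably degrees), its
# translation by less than 1e-5, and its intrinsics match exactly; guidance for
# such views is clamped down to min_scale, since they need little extrapolation.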


class ConstantScaleSchedule(object):
    def __call__(
        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> float | torch.Tensor:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")


class ConstantGuidance(object):
    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: float | torch.Tensor,
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)
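
# This is standard classifier-free guidance, uncond + scale * (cond - uncond):
# an extrapolation from the unconditional prediction toward the conditional one.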


class VanillaCFG(object):
    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
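
# prepare_inputs doubles the batch so the unconditional and conditional branches
# share a single network forward pass; __call__ then splits the doubled output
# with chunk(2) before combining the halves.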


class MultiviewCFG(VanillaCFG):
    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred


class MultiviewTemporalCFG(MultiviewCFG):
    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
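
# Per-frame guidance: each frame's index distance to the nearest input frame,
# normalized to [0, 1] within its sequence, interpolates the scale between
# scale_min (for frames at or next to inputs) and the full scale (for the
# farthest frames).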


class EulerEDMSampler(object):
    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
        num_steps: int | None = None,
        verbose: bool = False,
        device: str | torch.device = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d
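
    # Each sampler_step is one Euler update of the EDM ODE: sigma is optionally
    # inflated to sigma_hat via the churn parameter gamma (with fresh noise
    # injected to match the higher level), the guided denoiser is queried once
    # on the doubled batch, and x is stepped from sigma_hat to next_sigma along
    # d = (x - denoised) / sigma_hat.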

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
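

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the original module: all names below
    # (DummyNet, the latent shapes, the cond keys) are illustrative assumptions.
    class DummyNet(nn.Module):
        def forward(self, x, t, cond, **kwargs):
            # Pretend the predicted noise is zero everywhere.
            return torch.zeros_like(x)

    net = DummyNet()
    discretization = DDPMDiscretization()
    denoiser = DiscreteDenoiser(discretization)
    sampler = EulerEDMSampler(
        discretization, guider=VanillaCFG(), num_steps=5, device="cpu"
    )

    x = torch.randn(2, 4, 8, 8)  # (batch, channels, height, width) latents
    cond = {"crossattn": torch.zeros(2, 1, 16)}
    uc = {"crossattn": torch.zeros(2, 1, 16)}
    samples = sampler(
        # Bind the raw network into the denoiser's (input, sigma, cond) interface.
        lambda inp, sig, c: denoiser(net, inp, sig, c),
        x,
        scale=2.0,
        cond=cond,
        uc=uc,
    )
    print(samples.shape)  # torch.Size([2, 4, 8, 8])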