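"""Sampling utilities: noise-schedule discretization, denoiser wrapping,
classifier-free guidance (CFG) variants, and an Euler EDM sampler."""
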
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange
from tqdm import tqdm
from seva.geometry import get_camera_dist


def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor:
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"Input has {x.ndim} dims but target_dims is {target_dims}, which is fewer."
        )
    return x[(...,) + (None,) * dims_to_append]
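

# Append a terminal sigma of zero so a schedule ends at the clean sample.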
def append_zero(x: torch.Tensor) -> torch.Tensor:
    return torch.cat([x, x.new_zeros([1])])
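

# ODE derivative for Karras-style (EDM) samplers: d = (x - denoised) / sigma.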
def to_d(x: torch.Tensor, sigma: torch.Tensor, denoised: torch.Tensor) -> torch.Tensor:
    return (x - denoised) / append_dims(sigma, x.ndim)
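

# Quadratic ("sqrt-linear") beta schedule: linear in sqrt(beta), then squared.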
def make_betas(
    num_timesteps: int, linear_start: float = 1e-4, linear_end: float = 2e-2
) -> np.ndarray:
    betas = (
        torch.linspace(
            linear_start**0.5, linear_end**0.5, num_timesteps, dtype=torch.float64
        )
        ** 2
    )
    return betas.numpy()
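

# Select num_substeps timesteps out of max_step, roughly equally spaced and
# returned in ascending order.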
def generate_roughly_equally_spaced_steps(
    num_substeps: int, max_step: int
) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
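

# Scaling coefficients for an eps-prediction model. The denoiser output is
# assembled as D(x) = c_skip * x + c_out * network(c_in * x, c_noise), with
# c_skip = 1, c_out = -sigma, c_in = 1 / sqrt(sigma^2 + 1).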
class EpsScaling(object):
    def __call__(
        self, sigma: torch.Tensor
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        c_skip = torch.ones_like(sigma)
        c_out = -sigma
        c_in = 1 / (sigma**2 + 1.0) ** 0.5
        c_noise = sigma.clone()
        return c_skip, c_out, c_in, c_noise
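

# DDPM noise schedule exposed as continuous sigmas,
# sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t), with an optional shift that
# multiplies all sigmas by exp(log_snr_shift).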
class DDPMDiscretization(object):
    def __init__(
        self,
        linear_start: float = 5e-06,
        linear_end: float = 0.012,
        num_timesteps: int = 1000,
        log_snr_shift: float | None = 2.4,
    ):
        self.num_timesteps = num_timesteps
        betas = make_betas(
            num_timesteps,
            linear_start=linear_start,
            linear_end=linear_end,
        )
        self.log_snr_shift = log_snr_shift
        alphas = 1.0 - betas  # first alpha here is on data side
        self.alphas_cumprod = np.cumprod(alphas, axis=0)

    def get_sigmas(self, n: int, device: str | torch.device = "cpu") -> torch.Tensor:
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"Expected n <= {self.num_timesteps}, but got n = {n}.")
        sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        if self.log_snr_shift is not None:
            sigmas = sigmas * np.exp(self.log_snr_shift)
        return torch.flip(
            torch.tensor(sigmas, dtype=torch.float32, device=device), (0,)
        )

    def __call__(
        self,
        n: int,
        do_append_zero: bool = True,
        flip: bool = False,
        device: str | torch.device = "cpu",
    ) -> torch.Tensor:
        sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        return sigmas if not flip else torch.flip(sigmas, (0,))
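

# Wraps a network for sampling with continuous sigmas: each sigma is snapped to
# the nearest entry of the discrete schedule, and the EpsScaling coefficients
# are applied around the network call.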
class DiscreteDenoiser(object):
    sigmas: torch.Tensor

    def __init__(
        self,
        discretization: DDPMDiscretization,
        num_idx: int = 1000,
        device: str | torch.device = "cpu",
    ):
        self.scaling = EpsScaling()
        self.discretization = discretization
        self.num_idx = num_idx
        self.device = device
        self.register_sigmas()

    def register_sigmas(self):
        self.sigmas = self.discretization(
            self.num_idx, do_append_zero=False, flip=True, device=self.device
        )

    def sigma_to_idx(self, sigma: torch.Tensor) -> torch.Tensor:
        dists = sigma - self.sigmas[:, None]
        return dists.abs().argmin(dim=0).view(sigma.shape)

    def idx_to_sigma(self, idx: torch.Tensor | int) -> torch.Tensor:
        return self.sigmas[idx]

    def __call__(
        self,
        network: nn.Module,
        input: torch.Tensor,
        sigma: torch.Tensor,
        cond: dict,
        **additional_model_inputs,
    ) -> torch.Tensor:
        sigma = self.idx_to_sigma(self.sigma_to_idx(sigma))
        sigma_shape = sigma.shape
        sigma = append_dims(sigma, input.ndim)
        c_skip, c_out, c_in, c_noise = self.scaling(sigma)
        c_noise = self.sigma_to_idx(c_noise.reshape(sigma_shape))
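        # "replace" conditioning: overwrite the masked region of the input with
        # the provided latents before denoising.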
        if "replace" in cond:
            x, mask = cond.pop("replace").split((input.shape[1], 1), dim=1)
            input = input * (1 - mask) + x * mask
        return (
            network(input * c_in, c_noise, cond, **additional_model_inputs) * c_out
            + input * c_skip
        )
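

# Use the requested guidance scale as-is.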
class ConstantScaleRule(object):
    def __call__(self, scale: float | torch.Tensor) -> float | torch.Tensor:
        return scale
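

# Per-frame guidance scale for multiview sampling: frames whose camera pose and
# intrinsics (nearly) coincide with an input view are guided at min_scale.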
class MultiviewScaleRule(object):
    def __init__(self, min_scale: float = 1.0):
        self.min_scale = min_scale

    def __call__(
        self,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        c2w_input = c2w[input_frame_mask]
        rotation_diff = get_camera_dist(c2w, c2w_input, mode="rotation").min(-1).values
        translation_diff = (
            get_camera_dist(c2w, c2w_input, mode="translation").min(-1).values
        )
        K_diff = (
            ((K[:, None] - K[input_frame_mask][None]).flatten(-2) == 0).all(-1).any(-1)
        )
        close_frame = (rotation_diff < 10.0) & (translation_diff < 1e-5) & K_diff
        if isinstance(scale, torch.Tensor):
            scale = scale.clone()
            scale[close_frame] = self.min_scale
        elif isinstance(scale, float):
            scale = torch.where(close_frame, self.min_scale, scale)
        else:
            raise ValueError(f"Invalid scale type {type(scale)}.")
        return scale
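

# Keep the guidance scale constant across noise levels.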
class ConstantScaleSchedule(object):
    def __call__(
        self, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> float | torch.Tensor:
        if isinstance(sigma, float):
            return scale
        elif isinstance(sigma, torch.Tensor):
            if len(sigma.shape) == 1 and isinstance(scale, torch.Tensor):
                sigma = append_dims(sigma, scale.ndim)
            return scale * torch.ones_like(sigma)
        else:
            raise ValueError(f"Invalid sigma type {type(sigma)}.")
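

# Standard classifier-free guidance combination: uncond + scale * (cond - uncond).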
class ConstantGuidance(object):
    def __call__(
        self,
        uncond: torch.Tensor,
        cond: torch.Tensor,
        scale: float | torch.Tensor,
    ) -> torch.Tensor:
        if isinstance(scale, torch.Tensor) and len(scale.shape) == 1:
            scale = append_dims(scale, cond.ndim)
        return uncond + scale * (cond - uncond)
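

# Vanilla CFG guider: duplicates the batch for the unconditional and conditional
# branches and combines the two predictions.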
class VanillaCFG(object):
    def __init__(self):
        self.scale_rule = ConstantScaleRule()
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(
        self, x: torch.Tensor, sigma: float | torch.Tensor, scale: float | torch.Tensor
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred

    def prepare_inputs(
        self, x: torch.Tensor, s: torch.Tensor, c: dict, uc: dict
    ) -> tuple[torch.Tensor, torch.Tensor, dict]:
        c_out = dict()
        for k in c:
            if k in ["vector", "crossattn", "concat", "replace", "dense_vector"]:
                c_out[k] = torch.cat((uc[k], c[k]), 0)
            else:
                assert c[k] == uc[k]
                c_out[k] = c[k]
        return torch.cat([x] * 2), torch.cat([s] * 2), c_out
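

# CFG with the multiview scale rule: frames that coincide with input views are
# guided at the minimum scale.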
class MultiviewCFG(VanillaCFG):
    def __init__(self, cfg_min: float = 1.0):
        self.scale_min = cfg_min
        self.scale_rule = MultiviewScaleRule(min_scale=cfg_min)
        self.scale_schedule = ConstantScaleSchedule()
        self.guidance = ConstantGuidance()

    def __call__(  # type: ignore
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        x_u, x_c = x.chunk(2)
        scale = self.scale_rule(scale, c2w, K, input_frame_mask)
        scale_value = self.scale_schedule(sigma, scale)
        x_pred = self.guidance(x_u, x_c, scale_value)
        return x_pred
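

# Multiview CFG with a temporal ramp: within each window, the guidance scale
# grows with a frame's normalized distance to the nearest input frame.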
class MultiviewTemporalCFG(MultiviewCFG):
    def __init__(self, num_frames: int, cfg_min: float = 1.0):
        super().__init__(cfg_min=cfg_min)
        self.num_frames = num_frames
        distance_matrix = (
            torch.arange(num_frames)[None] - torch.arange(num_frames)[:, None]
        ).abs()
        self.distance_matrix = distance_matrix

    def __call__(
        self,
        x: torch.Tensor,
        sigma: float | torch.Tensor,
        scale: float | torch.Tensor,
        c2w: torch.Tensor,
        K: torch.Tensor,
        input_frame_mask: torch.Tensor,
    ) -> torch.Tensor:
        input_frame_mask = rearrange(
            input_frame_mask, "(b t) ... -> b t ...", t=self.num_frames
        )
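        # Distance from each frame to the nearest input frame; non-input frames
        # are penalized so they are never selected as the nearest.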
        min_distance = (
            self.distance_matrix[None].to(x.device)
            + (~input_frame_mask[:, None]) * self.num_frames
        ).min(-1)[0]
        min_distance = min_distance / min_distance.max(-1, keepdim=True)[0].clamp(min=1)
        scale = min_distance * (scale - self.scale_min) + self.scale_min
        scale = rearrange(scale, "b t ... -> (b t) ...")
        scale = append_dims(scale, x.ndim)
        return super().__call__(x, sigma, scale, c2w, K, input_frame_mask.flatten(0, 1))
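

# Euler sampler in the EDM formulation (Karras et al., 2022), with optional
# stochastic "churn" controlled by s_churn / s_tmin / s_tmax / s_noise.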
class EulerEDMSampler(object):
    def __init__(
        self,
        discretization: DDPMDiscretization,
        guider: VanillaCFG | MultiviewCFG | MultiviewTemporalCFG,
        num_steps: int | None = None,
        verbose: bool = False,
        device: str | torch.device = "cuda",
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
    ):
        self.num_steps = num_steps
        self.discretization = discretization
        self.guider = guider
        self.verbose = verbose
        self.device = device
        self.s_churn = s_churn
        self.s_tmin = s_tmin
        self.s_tmax = s_tmax
        self.s_noise = s_noise

    def prepare_sampling_loop(
        self, x: torch.Tensor, cond: dict, uc: dict, num_steps: int | None = None
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, dict, dict]:
        num_steps = num_steps or self.num_steps
        assert num_steps is not None, "num_steps must be specified"
        sigmas = self.discretization(num_steps, device=self.device)
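        # Scale the initial noise up to the first (largest) sigma level.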
        x *= torch.sqrt(1.0 + sigmas[0] ** 2.0)
        num_sigmas = len(sigmas)
        s_in = x.new_ones([x.shape[0]])
        return x, s_in, sigmas, num_sigmas, cond, uc

    def get_sigma_gen(self, num_sigmas: int, verbose: bool = True) -> range | tqdm:
        sigma_generator = range(num_sigmas - 1)
        if self.verbose and verbose:
            sigma_generator = tqdm(
                sigma_generator,
                total=num_sigmas - 1,
                desc="Sampling",
                leave=False,
            )
        return sigma_generator

    def sampler_step(
        self,
        sigma: torch.Tensor,
        next_sigma: torch.Tensor,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict,
        gamma: float = 0.0,
        **guider_kwargs,
    ) -> torch.Tensor:
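        # "Churn": temporarily inflate sigma by (1 + gamma) and add matching
        # fresh noise, following Karras et al. (2022).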
        sigma_hat = sigma * (gamma + 1.0) + 1e-6
        eps = torch.randn_like(x) * self.s_noise
        x = x + eps * append_dims(sigma_hat**2 - sigma**2, x.ndim) ** 0.5
        denoised = denoiser(*self.guider.prepare_inputs(x, sigma_hat, cond, uc))
        denoised = self.guider(denoised, sigma_hat, scale, **guider_kwargs)
        d = to_d(x, sigma_hat, denoised)
        dt = append_dims(next_sigma - sigma_hat, x.ndim)
        return x + dt * d

    def __call__(
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        **guider_kwargs,
    ) -> torch.Tensor:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
        return x
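

# A minimal usage sketch (hypothetical `network`, `cond_dict`, and `uc_dict`;
# any nn.Module matching the DiscreteDenoiser call signature works):
#
#     discretization = DDPMDiscretization()
#     denoiser = DiscreteDenoiser(discretization, device="cuda")
#     sampler = EulerEDMSampler(discretization, VanillaCFG(), num_steps=50)
#     x = torch.randn(4, 4, 64, 64, device="cuda")
#     samples = sampler(
#         lambda xi, si, ci: denoiser(network, xi, si, ci),
#         x,
#         scale=2.0,
#         cond=cond_dict,
#         uc=uc_dict,
#     )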