from abc import abstractmethod
from functools import partial

import numpy as np
import torch

from ...modules.diffusionmodules.util import make_beta_schedule
from ...util import append_zero


def generate_roughly_equally_spaced_steps(num_substeps: int, max_step: int) -> np.ndarray:
    return np.linspace(max_step - 1, 0, num_substeps, endpoint=False).astype(int)[::-1]
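

# Worked example (illustrative): thinning a 1000-step schedule down to 4 substeps
# picks ascending indices that always include the final step:
#
#     generate_roughly_equally_spaced_steps(4, 1000)  # -> array([249, 499, 749, 999])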


class Discretization:
    def __call__(self, n, do_append_zero=True, device="cpu", flip=False, return_idx=False):
        if return_idx:
            sigmas, idx = self.get_sigmas(n, device=device, return_idx=return_idx)
        else:
            # Do not forward return_idx here: not every subclass's get_sigmas accepts it.
            sigmas = self.get_sigmas(n, device=device)
        sigmas = append_zero(sigmas) if do_append_zero else sigmas
        if return_idx:
            return (sigmas if not flip else torch.flip(sigmas, (0,))), idx
        return sigmas if not flip else torch.flip(sigmas, (0,))

    @abstractmethod
    def get_sigmas(self, n, device):
        pass
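

# Usage sketch (illustrative; `discretization` stands for any subclass instance below).
# Calling the object, rather than get_sigmas() directly, optionally appends a final 0.0
# (the fully denoised level), flips the order, and/or passes through the source timestep
# indices for subclasses that support return_idx:
#
#     sigmas = discretization(n)                                   # n + 1 values, high -> low -> 0.0
#     sigmas = discretization(n, do_append_zero=False, flip=True)  # n values, low -> high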


class EDMDiscretization(Discretization):
    def __init__(self, sigma_min=0.002, sigma_max=80.0, rho=7.0):
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max
        self.rho = rho

    def get_sigmas(self, n, device="cpu"):
        ramp = torch.linspace(0, 1, n, device=device)
        min_inv_rho = self.sigma_min ** (1 / self.rho)
        max_inv_rho = self.sigma_max ** (1 / self.rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** self.rho
        return sigmas
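

# The schedule above is the EDM rho-schedule of Karras et al. (2022), interpolating
# linearly in sigma^(1/rho) space:
#
#     sigma_i = (sigma_max^(1/rho) + i/(n-1) * (sigma_min^(1/rho) - sigma_max^(1/rho)))^rho
#
# With the defaults and n=5 this gives roughly 80.0, 17.5, 2.5, 0.17, 0.002 -- a quick
# sanity check, not part of the API.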


class LegacyDDPMDiscretization(Discretization):
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
    ):
        super().__init__()
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

    def get_sigmas(self, n, device="cpu"):
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"n ({n}) must not exceed num_timesteps ({self.num_timesteps})")
        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        sigmas = to_torch((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
        return torch.flip(sigmas, (0,))  # sigma_t: ~14.6 -> 0.029
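

# For reference: with alpha_bar_t = prod_s(1 - beta_s), the values returned above are
#
#     sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t),
#
# i.e. the variance-exploding sigmas of a variance-preserving DDPM schedule, flipped
# to run from high noise to low noise. Illustrative sketch (values depend on the
# configured schedule):
#
#     disc = LegacyDDPMDiscretization()
#     sigmas = disc(30)  # 31 values: 30 sigmas descending from ~14.6, plus a trailing 0.0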


class ZeroSNRDDPMDiscretization(Discretization):
    def __init__(
        self,
        linear_start=0.00085,
        linear_end=0.0120,
        num_timesteps=1000,
        shift_scale=1.0,  # noise schedule shift t_n -> t_m: logSNR(t_m) = logSNR(t_n) - log(shift_scale)
        keep_start=False,
        post_shift=False,
    ):
        super().__init__()
        if keep_start and not post_shift:
            # Pre-compensate linear_start so the start of the schedule is preserved
            # under the SNR shift applied below.
            linear_start = linear_start / (shift_scale + (1 - shift_scale) * linear_start)
        self.num_timesteps = num_timesteps
        betas = make_beta_schedule("linear", num_timesteps, linear_start=linear_start, linear_end=linear_end)
        alphas = 1.0 - betas
        self.alphas_cumprod = np.cumprod(alphas, axis=0)
        self.to_torch = partial(torch.tensor, dtype=torch.float32)

        # SNR shift: alpha_bar -> alpha_bar / (s + (1 - s) * alpha_bar) lowers logSNR
        # uniformly by log(s), matching the shift_scale comment above.
        if not post_shift:
            self.alphas_cumprod = self.alphas_cumprod / (shift_scale + (1 - shift_scale) * self.alphas_cumprod)

        self.post_shift = post_shift
        self.shift_scale = shift_scale

    def get_sigmas(self, n, device="cpu", return_idx=False):
        if n < self.num_timesteps:
            timesteps = generate_roughly_equally_spaced_steps(n, self.num_timesteps)
            alphas_cumprod = self.alphas_cumprod[timesteps]
        elif n == self.num_timesteps:
            timesteps = np.arange(self.num_timesteps)  # fix: was undefined here when return_idx=True
            alphas_cumprod = self.alphas_cumprod
        else:
            raise ValueError(f"n ({n}) must not exceed num_timesteps ({self.num_timesteps})")

        to_torch = partial(torch.tensor, dtype=torch.float32, device=device)
        alphas_cumprod = to_torch(alphas_cumprod)
        alphas_cumprod_sqrt = alphas_cumprod.sqrt()
        alphas_cumprod_sqrt_0 = alphas_cumprod_sqrt[0].clone()
        alphas_cumprod_sqrt_T = alphas_cumprod_sqrt[-1].clone()

        # Rescale to zero terminal SNR: shift so sqrt(alpha_bar_T) = 0, then scale so
        # sqrt(alpha_bar_0) is unchanged (Lin et al., 2023, "Common Diffusion Noise
        # Schedules and Sample Steps are Flawed").
        alphas_cumprod_sqrt -= alphas_cumprod_sqrt_T
        alphas_cumprod_sqrt *= alphas_cumprod_sqrt_0 / (alphas_cumprod_sqrt_0 - alphas_cumprod_sqrt_T)

        if self.post_shift:
            alphas_cumprod_sqrt = (
                alphas_cumprod_sqrt**2 / (self.shift_scale + (1 - self.shift_scale) * alphas_cumprod_sqrt**2)
            ) ** 0.5

        if return_idx:
            return torch.flip(alphas_cumprod_sqrt, (0,)), timesteps
        return torch.flip(alphas_cumprod_sqrt, (0,))  # sqrt(alpha_bar_t): 0 -> ~0.999
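

# Minimal usage sketch (illustrative; the shift_scale value is arbitrary). Note the
# class returns sqrt(alpha_bar) levels rather than sigmas, rescaled so the terminal
# step has exactly zero SNR:
#
#     disc = ZeroSNRDDPMDiscretization(shift_scale=1.0)
#     vals = disc(50, do_append_zero=False)                        # 50 values, 0.0 -> ~0.999
#     vals, idx = disc(50, do_append_zero=False, return_idx=True)  # plus the 50 source timesteps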