Spaces:
Running
on
L40S
Running
on
L40S
| import numpy as np | |
| import torch | |
def append_dims(x, target_dims):
    """Return ``x`` with trailing singleton axes added so that it has ``target_dims`` dims.

    ``x`` must expose ``.ndim`` and support ``Ellipsis``/``None`` indexing
    (NumPy arrays and torch tensors both qualify).

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    extra = target_dims - x.ndim
    if extra >= 0:
        # ``x[..., None, None]``-style indexing with ``extra`` trailing Nones.
        return x[(Ellipsis, *([None] * extra))]
    raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
| # From LCMScheduler.get_scalings_for_boundary_condition_discrete | |
| def scalings_for_boundary_conditions(timestep, sigma_data=0.5, timestep_scaling=10.0): | |
| scaled_timestep = timestep_scaling * timestep | |
| c_skip = sigma_data**2 / (scaled_timestep**2 + sigma_data**2) | |
| c_out = scaled_timestep / (scaled_timestep**2 + sigma_data**2) ** 0.5 | |
| return c_skip, c_out | |
| def extract_into_tensor(a, t, x_shape): | |
| b, *_ = t.shape | |
| out = a.gather(-1, t) | |
| return out.reshape(b, *((1,) * (len(x_shape) - 1))) | |
class DDIMSolver:
    """Precomputed DDIM schedule for taking deterministic DDIM steps.

    Args:
        alpha_cumprods: 1-D array-like of cumulative alpha products indexed by
            training timestep (NumPy array, CPU torch tensor, or sequence). Its
            dtype is preserved in the stored tables. Presumably its length is at
            least ``timesteps`` — TODO confirm against caller.
        timesteps: number of training timesteps.
        ddim_timesteps: number of DDIM sampling steps; must satisfy
            ``1 <= ddim_timesteps <= timesteps``.

    Attributes (torch tensors, moved together by :meth:`to`):
        ddim_timesteps: int64 training timestep of each DDIM step.
        ddim_alpha_cumprods: alpha_cumprod at each DDIM step.
        ddim_alpha_cumprods_prev: alpha_cumprod at the preceding DDIM step
            (``alpha_cumprods[0]`` for the first step).

    Raises:
        ValueError: if ``ddim_timesteps`` is outside ``[1, timesteps]``.
    """

    def __init__(self, alpha_cumprods, timesteps=1000, ddim_timesteps=50):
        # Without this check, ddim_timesteps > timesteps gives step_ratio == 0 and
        # every index becomes -1 (silently the last alpha), while ddim_timesteps <= 0
        # crashes or yields empty tables.
        if not 1 <= ddim_timesteps <= timesteps:
            raise ValueError(
                f"ddim_timesteps must be in [1, timesteps={timesteps}], got {ddim_timesteps}"
            )
        # Accept CPU torch tensors and plain sequences as well as NumPy arrays.
        alpha_cumprods = np.asarray(alpha_cumprods)

        # DDIM sampling parameters: timesteps step_ratio-1, 2*step_ratio-1, ...
        # Pure integer arithmetic, so no rounding is required.
        step_ratio = timesteps // ddim_timesteps
        ddim_ts = np.arange(1, ddim_timesteps + 1, dtype=np.int64) * step_ratio - 1
        ddim_alphas = alpha_cumprods[ddim_ts]
        # "Previous" alpha for DDIM step k is the alpha of DDIM step k-1; step 0 falls
        # back to alpha_cumprods[0]. np.concatenate keeps the input dtype, whereas the
        # former list round-trip (tolist + asarray) silently upcast this table to
        # float64, which in turn promoted ddim_step's output to float64.
        ddim_alphas_prev = np.concatenate([alpha_cumprods[:1], alpha_cumprods[ddim_ts[:-1]]])

        # convert to torch tensors
        self.ddim_timesteps = torch.from_numpy(ddim_ts).long()
        self.ddim_alpha_cumprods = torch.from_numpy(ddim_alphas)
        self.ddim_alpha_cumprods_prev = torch.from_numpy(ddim_alphas_prev)

    def to(self, device):
        """Move every schedule tensor to ``device`` (rebinding the attributes); return ``self``."""
        self.ddim_timesteps = self.ddim_timesteps.to(device)
        self.ddim_alpha_cumprods = self.ddim_alpha_cumprods.to(device)
        self.ddim_alpha_cumprods_prev = self.ddim_alpha_cumprods_prev.to(device)
        return self

    def ddim_step(self, pred_x0, pred_noise, timestep_index):
        """Return ``sqrt(a_prev) * pred_x0 + sqrt(1 - a_prev) * pred_noise``.

        ``a_prev`` is ``ddim_alpha_cumprods_prev[timestep_index]``, reshaped to
        broadcast over ``pred_x0``'s trailing dims. No noise is injected, so the step
        is deterministic. ``timestep_index`` is presumably a 1-D int64 tensor of DDIM
        step indices (one per batch element) on the same device as the schedule.
        """
        alpha_cumprod_prev = extract_into_tensor(self.ddim_alpha_cumprods_prev, timestep_index, pred_x0.shape)
        dir_xt = (1.0 - alpha_cumprod_prev).sqrt() * pred_noise
        x_prev = alpha_cumprod_prev.sqrt() * pred_x0 + dir_xt
        return x_prev