# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

import numpy as np
import scipy.signal
import torch
from torch_utils import persistence
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from torch_utils.ops import grid_sample_gradfix
from torch_utils.ops import conv2d_gradfix
from training.diffaug import DiffAugment
from training.adaaug import AdaAugment

#----------------------------------------------------------------------------
# Helpers for doing the diffusion process.

def get_beta_schedule(beta_schedule, beta_start, beta_end, num_diffusion_timesteps):
    """Build the per-step noise variances (betas) of a diffusion process.

    Args:
        beta_schedule: Name of the schedule. One of ``"continuous_t"``,
            ``"quad"``, ``"linear"``, ``"const"``, ``"jsd"``, ``"sigmoid"``
            or ``"cosine"``.
        beta_start: First beta value. Ignored by the ``"continuous_t"``,
            ``"const"``, ``"jsd"`` and ``"cosine"`` schedules.
        beta_end: Last beta value. Ignored by the ``"continuous_t"``,
            ``"jsd"`` and ``"cosine"`` schedules.
        num_diffusion_timesteps: Number of diffusion steps ``T``.

    Returns:
        A 1-D NumPy array holding ``num_diffusion_timesteps`` betas.

    Raises:
        NotImplementedError: If ``beta_schedule`` is not a known name.
    """
    def sigmoid(x):
        return 1 / (np.exp(-x) + 1)

    def continuous_t_beta(t, T):
        # Uses fixed bounds b_min / b_max instead of beta_start / beta_end.
        b_max = 5.
        b_min = 0.1
        alpha = np.exp(-b_min / T - 0.5 * (b_max - b_min) * (2 * t - 1) / T ** 2)
        return 1 - alpha

    if beta_schedule == "continuous_t":
        betas = continuous_t_beta(
            np.arange(1, num_diffusion_timesteps + 1), num_diffusion_timesteps
        )
    elif beta_schedule == "quad":
        betas = (
            np.linspace(
                beta_start ** 0.5,
                beta_end ** 0.5,
                num_diffusion_timesteps,
                dtype=np.float64,
            )
            ** 2
        )
    elif beta_schedule == "linear":
        betas = np.linspace(
            beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "const":
        betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
    elif beta_schedule == "jsd":
        # 1/T, 1/(T-1), 1/(T-2), ..., 1
        betas = 1.0 / np.linspace(
            num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
        )
    elif beta_schedule == "sigmoid":
        betas = np.linspace(-6, 6, num_diffusion_timesteps)
        betas = sigmoid(betas) * (beta_end - beta_start) + beta_start
    elif beta_schedule == 'cosine':
        # Cosine schedule as proposed in
        # https://openreview.net/forum?id=-NEXDKk8gZ
        s = 0.008
        steps = num_diffusion_timesteps + 1
        x = np.linspace(0, steps, steps)
        alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
        alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
        betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
        betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
        # Returned directly: this branch skips the shape assertion below.
        return betas_clipped
    else:
        raise NotImplementedError(beta_schedule)
    assert betas.shape == (num_diffusion_timesteps,)
    return betas


def _sample_noise(x_0, noise_type, noise_std):
    """Draw noise with the shape of ``x_0`` and scale it by ``noise_std``.

    ``'gauss'`` draws standard normal noise; ``'bernoulli'`` draws values in
    {-1, +1} with equal probability. Any other ``noise_type`` raises
    ``NotImplementedError``.
    """
    if noise_type == 'gauss':
        return torch.randn_like(x_0, device=x_0.device) * noise_std
    if noise_type == 'bernoulli':
        return (torch.bernoulli(torch.ones_like(x_0) * 0.5) * 2 - 1.) * noise_std
    raise NotImplementedError(noise_type)


def q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
    """Diffuse ``x_0`` to timestep ``t`` using one timestep per sample.

    Args:
        x_0: Input tensor; the coefficients are broadcast as ``(-1, 1, 1, 1)``.
        alphas_bar_sqrt: Table of signal coefficients, indexed by ``t``.
        one_minus_alphas_bar_sqrt: Table of noise coefficients, indexed by ``t``.
        t: Index tensor selecting one table entry per sample.
        noise_type: ``'gauss'`` or ``'bernoulli'``.
        noise_std: Scale applied to the sampled noise.

    Returns:
        ``alphas_bar_sqrt[t] * x_0 + one_minus_alphas_bar_sqrt[t] * noise``.

    Raises:
        NotImplementedError: If ``noise_type`` is not supported.
    """
    noise = _sample_noise(x_0, noise_type, noise_std)
    alphas_t_sqrt = alphas_bar_sqrt[t].view(-1, 1, 1, 1)
    one_minus_alphas_bar_t_sqrt = one_minus_alphas_bar_sqrt[t].view(-1, 1, 1, 1)
    x_t = alphas_t_sqrt * x_0 + one_minus_alphas_bar_t_sqrt * noise
    return x_t


def q_sample_c(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t, noise_type='gauss', noise_std=1.0):
    """Diffuse ``x_0`` using one timestep per (sample, channel) pair.

    Same as ``q_sample`` except that the looked-up coefficients are viewed as
    ``(batch_size, num_channels, 1, 1)``, so ``t`` must select
    ``batch_size * num_channels`` entries.

    Raises:
        NotImplementedError: If ``noise_type`` is not supported.
    """
    batch_size, num_channels, _, _ = x_0.shape
    noise = _sample_noise(x_0, noise_type, noise_std)
    alphas_t_sqrt = alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
    one_minus_alphas_bar_t_sqrt = one_minus_alphas_bar_sqrt[t].view(batch_size, num_channels, 1, 1)
    x_t = alphas_t_sqrt * x_0 + one_minus_alphas_bar_t_sqrt * noise
    return x_t


class Identity(torch.nn.Module):
    """Module that returns its input unchanged."""

    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


@persistence.persistent_class
class AugmentPipe(torch.nn.Module):
    """Optionally augment images, then diffuse them to a random timestep.

    The largest available timestep grows from ``t_min`` to ``t_max`` as the
    externally adjusted strength ``self.p`` goes from 0 to 1; call
    ``update_T()`` after changing ``self.p``.

    Args:
        beta_schedule: Schedule name passed to ``get_beta_schedule``, or
            ``'linear_cosine'`` to switch between ``'linear'`` and
            ``'cosine'`` depending on the current number of timesteps.
        beta_start: First beta of the schedule.
        beta_end: Last beta of the schedule.
        t_min: Smallest value of the maximum timestep.
        t_max: Largest value of the maximum timestep.
        noise_std: Scale of the injected noise.
        aug: ``'ADA'``, ``'DIFF'``, or anything else for no extra augmentation.
        ada_maxp: Optional cap on the probability given to ``AdaAugment``.
        ts_dist: ``'priority'`` or ``'uniform'`` timestep sampling.
        update_beta: If true, rebuild the schedule on every ``update_T()``;
            otherwise build it once for ``t_max`` steps.
    """

    def __init__(self,
        beta_schedule='linear', beta_start=1e-4, beta_end=2e-2,
        t_min=10, t_max=1000, noise_std=0.05,
        aug='NO', ada_maxp=None, ts_dist='priority', update_beta=True,
    ):
        super().__init__()
        self.p = 0.0  # Overall multiplier for augmentation probability.
        self.aug_type = aug
        self.ada_maxp = ada_maxp
        self.noise_type = self.base_noise_type = 'gauss'
        self.beta_schedule = beta_schedule
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.t_min = t_min
        self.t_max = t_max
        self.t_add = int(t_max - t_min)
        self.ts_dist = ts_dist

        # Image-space corruptions.
        self.noise_std = float(noise_std)  # Standard deviation of additive RGB noise.

        if aug == 'ADA':
            self.aug = AdaAugment(p=0.0)
        elif aug == 'DIFF':
            self.aug = DiffAugment()
        else:
            self.aug = Identity()

        self.update_beta = update_beta
        if not update_beta:
            self.set_diffusion_process(t_max, beta_schedule)

        self.update_T()

    def set_diffusion_process(self, t, beta_schedule):
        """Build the diffusion coefficient tables for ``t`` steps.

        Both tables get a leading entry for timestep 0 (signal coefficient 1,
        noise coefficient 0), so they hold ``t + 1`` values and index 0 leaves
        an image unchanged.
        """
        betas = get_beta_schedule(
            beta_schedule=beta_schedule,
            beta_start=self.beta_start,
            beta_end=self.beta_end,
            num_diffusion_timesteps=t,
        )

        betas = self.betas = torch.from_numpy(betas).float()
        self.num_timesteps = betas.shape[0]

        alphas = self.alphas = 1.0 - betas
        alphas_cumprod = torch.cat([torch.tensor([1.]), alphas.cumprod(dim=0)])
        self.alphas_bar_sqrt = torch.sqrt(alphas_cumprod)
        self.one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_cumprod)

    def update_T(self):
        """Refresh the maximum timestep from ``self.p`` and resample timesteps.

        Fills ``self.t_epl`` (64 entries): the first 32 are timesteps drawn
        from ``1..t`` according to ``self.ts_dist``, the remaining 32 stay 0,
        i.e. they select the undiffused image in ``forward``.
        """
        if self.aug_type == 'ADA':
            _p = min(self.p, self.ada_maxp) if self.ada_maxp else self.p
            self.aug.p.copy_(torch.tensor(_p))

        t_adjust = round(self.p * self.t_add)
        # Plain Python int so downstream code never sees a NumPy scalar.
        t = int(np.clip(int(self.t_min + t_adjust), a_min=self.t_min, a_max=self.t_max))

        if self.update_beta:
            if self.beta_schedule == 'linear_cosine':
                if t >= 500:
                    self.set_diffusion_process(t, 'cosine')
                else:
                    self.set_diffusion_process(t, 'linear')
            else:
                self.set_diffusion_process(t, self.beta_schedule)

        # sampling t
        # np.int was removed in NumPy 1.24; int64 also yields a long index
        # tensor in forward() on every platform.
        self.t_epl = np.zeros(64, dtype=np.int64)
        diffusion_ind = 32
        # Stays all-zero when ts_dist is neither 'priority' nor 'uniform'.
        t_diffusion = np.zeros((diffusion_ind,)).astype(np.int64)
        timesteps = np.arange(1, t + 1)
        if self.ts_dist == 'priority':
            # Timestep k gets weight k - 1: later timesteps are favoured and
            # timestep 1 is never drawn.
            weights = np.arange(t)
            total = weights.sum()
            if total > 0:
                t_diffusion = np.random.choice(timesteps, size=diffusion_ind, p=weights / total)
            else:
                # t == 1: the weights sum to zero (would give NaN
                # probabilities), so draw the only available timestep.
                t_diffusion = np.random.choice(timesteps, size=diffusion_ind)
        elif self.ts_dist == 'uniform':
            t_diffusion = np.random.choice(timesteps, size=diffusion_ind)
        self.t_epl[:diffusion_ind] = t_diffusion

    def forward(self, x_0):
        """Augment ``x_0`` and diffuse each sample to a random timestep.

        Args:
            x_0: 4-D image tensor; the first dimension is the batch.

        Returns:
            A tuple ``(x_t, t)`` with the diffused images and the timestep
            used for each sample, shaped ``(batch_size, 1)``.
        """
        x_0 = self.aug(x_0)
        assert isinstance(x_0, torch.Tensor) and x_0.ndim == 4
        batch_size = x_0.shape[0]
        device = x_0.device

        alphas_bar_sqrt = self.alphas_bar_sqrt.to(device)
        one_minus_alphas_bar_sqrt = self.one_minus_alphas_bar_sqrt.to(device)

        # One timestep per sample, drawn with replacement from the pool
        # prepared by update_T().
        t = torch.from_numpy(np.random.choice(self.t_epl, size=batch_size, replace=True)).to(device)
        x_t = q_sample(x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, t,
                       noise_type=self.noise_type,
                       noise_std=self.noise_std)
        return x_t, t.view(-1, 1)

#----------------------------------------------------------------------------