import os

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


class DNO(object):
    """Configuration holder and helpers for a noise-optimization loop.

    The constructor stores the optimization hyper-parameters, selects the
    hint loss function, and prepares the visualization outputs (an output
    sub-directory and a TensorBoard writer) when they are requested.

    Attributes set by ``__init__``:
        loss_hint_func: ``F.l1_loss``, ``F.smooth_l1_loss`` or ``F.mse_loss``,
            selected by ``loss_hint_type``.
        visualize_samples: Sample budget for visualization; ``float('inf')``
            when the string ``'inf'`` was passed.
        visualize_samples_done: Counter initialised to 0; it is compared with
            ``visualize_samples`` by :attr:`do_visualize`.
        vis_dir: ``<output_dir>/vis_optimize``. Only defined when
            ``visualize_ske_steps`` is non-empty.
        writer: A ``SummaryWriter`` on ``output_dir`` when
            ``visualize_samples > 0``, otherwise ``None``.
    """

    def __init__(
            self,
            optimize: bool,
            max_train_steps: int,
            learning_rate: float,
            lr_scheduler: str,
            lr_warmup_steps: int,
            clip_grad: bool,
            loss_hint_type: str,
            loss_diff_penalty: float,
            loss_correlate_penalty: float,
            visualize_samples: int,
            visualize_ske_steps: list[int],
            output_dir: str
    ) -> None:
        """Store the settings and set up the visualization outputs.

        Args:
            optimize: Stored as ``self.optimize``.
            max_train_steps: Stored as ``self.max_train_steps``.
            learning_rate: Stored as ``self.learning_rate``.
            lr_scheduler: Stored as ``self.lr_scheduler``.
            lr_warmup_steps: Stored as ``self.lr_warmup_steps``.
            clip_grad: Stored as ``self.clip_grad``.
            loss_hint_type: One of ``'l1'``, ``'l1_smooth'`` or ``'l2'``.
            loss_diff_penalty: Stored as ``self.loss_diff_penalty``.
            loss_correlate_penalty: Stored as ``self.loss_correlate_penalty``.
            visualize_samples: Non-negative number, or the string ``'inf'``
                for an unlimited budget.
            visualize_ske_steps: When non-empty, the ``vis_optimize``
                directory is created under ``output_dir``.
            output_dir: Root directory for the visualization directory and
                the TensorBoard writer.

        Raises:
            ValueError: If ``loss_hint_type`` is not a supported name.
            AssertionError: If ``visualize_samples`` is negative.
        """
        self.optimize = optimize
        self.max_train_steps = max_train_steps
        self.learning_rate = learning_rate
        self.lr_scheduler = lr_scheduler
        self.lr_warmup_steps = lr_warmup_steps
        self.clip_grad = clip_grad
        self.loss_hint_type = loss_hint_type
        self.loss_diff_penalty = loss_diff_penalty
        self.loss_correlate_penalty = loss_correlate_penalty

        hint_losses = {
            'l1': F.l1_loss,
            'l1_smooth': F.smooth_l1_loss,
            'l2': F.mse_loss,
        }
        if loss_hint_type not in hint_losses:
            raise ValueError(f'Invalid loss type: {loss_hint_type}')
        self.loss_hint_func = hint_losses[loss_hint_type]

        # The string 'inf' is accepted as a spelling of "no limit".
        self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples
        assert self.visualize_samples >= 0
        self.visualize_samples_done = 0
        self.visualize_ske_steps = visualize_ske_steps
        if len(visualize_ske_steps) > 0:
            self.vis_dir = os.path.join(output_dir, 'vis_optimize')
            # exist_ok: building a second instance (or restarting a run) on
            # the same output_dir must not fail with FileExistsError.
            os.makedirs(self.vis_dir, exist_ok=True)

        self.writer = None
        self.output_dir = output_dir
        if self.visualize_samples > 0:
            self.writer = SummaryWriter(output_dir)

    @property
    def do_visualize(self):
        """Return True while ``visualize_samples_done`` is below the budget."""
        return self.visualize_samples_done < self.visualize_samples

    @staticmethod
    def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor:
        """Multi-scale penalty on correlation between neighbours along ``dim``.

        At each scale the tensor is multiplied element-wise by a copy of
        itself rolled by one position along ``dim``; the product is averaged
        over every dimension except dimension 0 and squared. The tensor is
        then halved along ``dim`` by averaging adjacent pairs, and the
        process repeats while the length along ``dim`` exceeds ``stop_at``.

        If the length along ``dim`` is not a power of two, the tensor is
        first padded along ``dim`` up to the next power of two with fresh
        ``torch.randn`` samples, so the result is stochastic in that case.

        Args:
            noise: Input tensor. Dimension 0 is kept un-reduced in the result.
            stop_at: The loop ends once the length along ``dim`` is no longer
                greater than this value.
            dim: Dimension to process. Negative values index from the end.
                NOTE(review): the reduction treats dimension 0 as the batch,
                so ``dim`` is presumably never 0 -- confirm with callers.

        Returns:
            Tensor of shape ``(noise.shape[0],)`` holding the summed penalty.
        """
        if dim < 0:
            # The shape slicing below (shape[:dim], shape[dim + 1:]) is only
            # correct for a non-negative index.
            dim += noise.ndim

        size = noise.shape[dim]
        if size & (size - 1) != 0:
            # Not a power of two: pad up to the next one so that repeated
            # halving always splits the dimension evenly.
            new_size = 2 ** (size - 1).bit_length()
            pad = new_size - size
            pad_shape = list(noise.shape)
            pad_shape[dim] = pad
            pad_noise = torch.randn(*pad_shape, device=noise.device)
            noise = torch.cat([noise, pad_noise], dim=dim)
            size = noise.shape[dim]

        loss = torch.zeros(noise.shape[0], device=noise.device)
        while size > stop_at:
            rolled_noise = torch.roll(noise, shifts=1, dims=dim)
            loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2)
            # Split `dim` into (size // 2, 2) and average each pair, which
            # halves the resolution along `dim`.
            noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1)
            size //= 2
        return loss