import os

import torch
import torch.nn.functional as F
from torch.utils.tensorboard import SummaryWriter


class DNO(object):

    def __init__(
        self,
        optimize: bool,
        max_train_steps: int,
        learning_rate: float,
        lr_scheduler: str,
        lr_warmup_steps: int,
        clip_grad: bool,
        loss_hint_type: str,
        loss_diff_penalty: float,
        loss_correlate_penalty: float,
        visualize_samples: int,
        visualize_ske_steps: list[int],
        output_dir: str,
    ) -> None:
        # Optimization hyperparameters.
        self.optimize = optimize
        self.max_train_steps = max_train_steps
        self.learning_rate = learning_rate
        self.lr_scheduler = lr_scheduler
        self.lr_warmup_steps = lr_warmup_steps
        self.clip_grad = clip_grad

        # Loss configuration: pick the distance function used for the hint
        # loss and store the coefficients of the diff and correlate penalties.
        self.loss_hint_type = loss_hint_type
        self.loss_diff_penalty = loss_diff_penalty
        self.loss_correlate_penalty = loss_correlate_penalty
        if loss_hint_type == 'l1':
            self.loss_hint_func = F.l1_loss
        elif loss_hint_type == 'l1_smooth':
            self.loss_hint_func = F.smooth_l1_loss
        elif loss_hint_type == 'l2':
            self.loss_hint_func = F.mse_loss
        else:
            raise ValueError(f'Invalid loss type: {loss_hint_type}')

        # Visualization settings. `visualize_samples` may also be passed as
        # the string 'inf' to remove the cap on visualized samples.
        self.visualize_samples = float('inf') if visualize_samples == 'inf' else visualize_samples
        assert self.visualize_samples >= 0
        self.visualize_samples_done = 0
        self.visualize_ske_steps = visualize_ske_steps
        if len(visualize_ske_steps) > 0:
            self.vis_dir = os.path.join(output_dir, 'vis_optimize')
            os.makedirs(self.vis_dir)

        self.writer = None
        self.output_dir = output_dir
        if self.visualize_samples > 0:
            self.writer = SummaryWriter(output_dir)

    @property
    def do_visualize(self):
        # Keep visualizing until the configured number of samples is reached.
        return self.visualize_samples_done < self.visualize_samples

    @staticmethod
    def noise_regularize_1d(noise: torch.Tensor, stop_at: int = 2, dim: int = 1) -> torch.Tensor:
        # Penalize the lag-1 autocorrelation of `noise` along `dim` at
        # progressively coarser scales, returning one value per batch item.
        size = noise.shape[dim]

        # Pad the regularized dimension up to the next power of two with fresh
        # Gaussian noise so it can be repeatedly halved below.
        if size & (size - 1) != 0:
            new_size = 2 ** (size - 1).bit_length()
            pad = new_size - size
            pad_shape = list(noise.shape)
            pad_shape[dim] = pad
            pad_noise = torch.randn(*pad_shape, device=noise.device)
            noise = torch.cat([noise, pad_noise], dim=dim)
            size = noise.shape[dim]

        loss = torch.zeros(noise.shape[0], device=noise.device)
        while size > stop_at:
            # Squared mean of the product with a shift-by-one copy: close to
            # zero in expectation for white noise, large for correlated signals.
            rolled_noise = torch.roll(noise, shifts=1, dims=dim)
            loss += (noise * rolled_noise).mean(dim=tuple(range(1, noise.ndim))).pow(2)
            # Downsample by averaging adjacent pairs along `dim`, then apply
            # the same penalty at the coarser resolution.
            noise = noise.view(*noise.shape[:dim], size // 2, 2, *noise.shape[dim + 1:]).mean(dim=dim + 1)
            size //= 2
        return loss
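

# A minimal usage sketch, assuming batch-first tensors with the sequence axis
# at `dim=1`; the shapes below and the `__main__` guard are illustrative
# assumptions, not part of the original training pipeline.
if __name__ == '__main__':
    # A batch of 4 noise sequences of length 196 (not a power of two, so the
    # method pads with fresh Gaussian noise before downsampling).
    white = torch.randn(4, 196, 64)
    reg_white = DNO.noise_regularize_1d(white, stop_at=2, dim=1)
    print(reg_white.shape)  # torch.Size([4]) -- one penalty value per sample
    # A temporally correlated signal (a cumulative sum) typically receives a
    # much larger penalty than white noise.
    correlated = torch.cumsum(torch.randn(4, 196, 64), dim=1)
    reg_corr = DNO.noise_regularize_1d(correlated, stop_at=2, dim=1)
    print(reg_corr.mean() > reg_white.mean())  # expected: tensor(True)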