import torch
import torch.nn.functional as F
from tqdm import tqdm

from sam_diffsr.utils_sr.hparams import hparams

from .diffusion import GaussianDiffusion, noise_like, extract
from .module_util import default


class GaussianDiffusion_sam(GaussianDiffusion):
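    """GaussianDiffusion variant for SAM-guided super-resolution: during
    training it can add an (optionally timestep-scaled) SAM segmentation mask
    to the Gaussian noise, and it passes the mask to the denoiser throughout
    sampling."""
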
    def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1', sam_config=None):
        super().__init__(denoise_fn, rrdb_net, timesteps, loss_type)
        self.sam_config = sam_config
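        # NOTE: when sam_config['mask_coefficient'] is truthy, p_losses() scales
        # the mask by a per-timestep tensor `self.mask_coefficient` (indexed via
        # extract()). That attribute is not defined in this excerpt and is
        # presumably set up elsewhere in the project.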

    def p_losses(self, x_start, t, cond, img_lr_up, noise=None, sam_mask=None):
        noise = default(noise, lambda: torch.randn_like(x_start))

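        # Structure-aware noise: resize the SAM mask to the noise resolution,
        # optionally scale it per timestep, then add it to the Gaussian noise.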
        if self.sam_config['p_losses_sam']:
            _sam_mask = F.interpolate(sam_mask, size=noise.shape[2:], mode='bilinear')
            if self.sam_config.get('mask_coefficient', False):
                _sam_mask *= extract(self.mask_coefficient.to(_sam_mask.device), t, x_start.shape)
            noise += _sam_mask

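        # Diffuse x_start to two adjacent timesteps with the same noise, predict
        # the noise from the later one, and take one reverse step.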
        x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
        x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
        noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up, sam_mask=sam_mask)
        x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred, sam_mask=sam_mask)

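        # Noise-prediction loss against the (mask-augmented) noise target.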
        if self.loss_type == 'l1':
            loss = (noise - noise_pred).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, noise_pred)
        elif self.loss_type == 'ssim':
            loss = (noise - noise_pred).abs().mean()
            loss = loss + (1 - self.ssim_loss(noise, noise_pred))
        else:
            raise NotImplementedError(f'unknown loss_type: {self.loss_type}')
        return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

    @torch.no_grad()
    def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False, sam_mask=None):
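        """Take one reverse-diffusion step from x at timestep t, reusing a
        precomputed noise_pred when one is supplied."""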
        if noise_pred is None:
            noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up, sam_mask=sam_mask)
        b, *_, device = *x.shape, x.device
        model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
            x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)
        noise = noise_like(x.shape, device, repeat_noise)
        # no noise when t == 0
        nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred

    @torch.no_grad()
    def sample(self, img_lr, img_lr_up, shape, sam_mask=None, save_intermediate=False):
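        """Run the full reverse-diffusion loop, conditioned on the LR input."""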
        device = self.betas.device
        b = shape[0]

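        # Initialize the chain: pure noise in residual mode, otherwise a fully
        # diffused version of the upsampled LR image.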
        if not hparams['res']:
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
            img = self.q_sample(img_lr_up, t, noise=None)
        else:
            img = torch.randn(shape, device=device)

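        # Conditioning: RRDB features when enabled, otherwise the raw LR image.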
        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr

        it = reversed(range(0, self.num_timesteps))
        if self.sample_tqdm:
            it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)

        images = []
        for i in it:
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up, sam_mask=sam_mask)
            if save_intermediate:
                img_ = self.res2img(img, img_lr_up)
                x_recon_ = self.res2img(x_recon, img_lr_up)
                images.append((img_.cpu(), x_recon_.cpu()))

        img = self.res2img(img, img_lr_up)
        if save_intermediate:
            return img, rrdb_out, images
        return img, rrdb_out
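

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original file): a minimal, hypothetical
# example of driving this class. `my_denoiser`, `my_rrdb`, and the config
# values below are assumptions, not the project's actual defaults.
#
#   diffusion = GaussianDiffusion_sam(
#       denoise_fn=my_denoiser,   # hypothetical U-Net style noise predictor
#       rrdb_net=my_rrdb,         # hypothetical RRDB conditioning network
#       timesteps=1000,
#       loss_type='l1',
#       sam_config={'p_losses_sam': True, 'mask_coefficient': False},
#   )
#   # Training: x_start is the HR target (or residual), t a batch of timestep
#   # indices, cond/img_lr_up the LR conditioning, sam_mask the SAM mask.
#   loss, *aux = diffusion.p_losses(x_start, t, cond, img_lr_up, sam_mask=sam_mask)
#   # Inference:
#   sr, rrdb_out = diffusion.sample(img_lr, img_lr_up, shape, sam_mask=sam_mask)
# ---------------------------------------------------------------------------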