import torch
import torch.nn.functional as F
from tqdm import tqdm

from sam_diffsr.utils_sr.hparams import hparams
from .diffusion import GaussianDiffusion, noise_like, extract
from .module_util import default


class GaussianDiffusion_sam(GaussianDiffusion):
    """``GaussianDiffusion`` variant that threads a ``sam_mask`` tensor through the model.

    Compared with the parent class, this subclass

    * optionally adds an interpolated ``sam_mask`` to the training noise in
      :meth:`p_losses` (controlled by ``sam_config``), and
    * forwards ``sam_mask`` to ``denoise_fn`` on every call in
      :meth:`p_losses`, :meth:`p_sample` and :meth:`sample`.
    """

    def __init__(self, denoise_fn, rrdb_net, timesteps=1000, loss_type='l1', sam_config=None):
        """Build the diffusion wrapper.

        Args:
            denoise_fn: Noise-prediction network; forwarded to the parent constructor.
            rrdb_net: Forwarded to the parent constructor.
            timesteps: Forwarded to the parent constructor.
            loss_type: Forwarded to the parent constructor; :meth:`p_losses`
                accepts ``'l1'``, ``'l2'`` and ``'ssim'``.
            sam_config: Optional mapping read by :meth:`p_losses`. Recognised keys
                are ``'p_losses_sam'`` and ``'mask_coefficient'``. ``None`` (the
                default) is treated as "no SAM options enabled".
        """
        super().__init__(denoise_fn, rrdb_net, timesteps, loss_type)
        self.sam_config = sam_config

    def p_losses(self, x_start, t, cond, img_lr_up, noise=None, sam_mask=None):
        """Compute the training loss for one batch at timesteps ``t``.

        Args:
            x_start: Clean target tensor that is diffused with ``q_sample``.
            t: Per-sample timestep tensor.
            cond: Conditioning input forwarded to ``denoise_fn``.
            img_lr_up: Forwarded to ``denoise_fn``.
            noise: Optional noise tensor. Falls back (via ``default``) to
                ``torch.randn_like(x_start)``. A tensor passed in here is not
                modified.
            sam_mask: Mask forwarded to ``denoise_fn``. Required when
                ``sam_config['p_losses_sam']`` is enabled, in which case it is
                bilinearly resized to the spatial size of ``noise`` and added
                to it.

        Returns:
            Tuple ``(loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred)`` where
            ``x_tp1_gt`` / ``x_t_gt`` are ``q_sample`` outputs at ``t`` / ``t - 1``
            and ``x_t_pred`` / ``x0_pred`` come from :meth:`p_sample`.

        Raises:
            ValueError: ``p_losses_sam`` is enabled but ``sam_mask`` is ``None``.
            NotImplementedError: ``self.loss_type`` is not one of the supported values.
        """
        noise = default(noise, lambda: torch.randn_like(x_start))

        # ``sam_config`` defaults to None in __init__; subscripting it directly
        # would raise TypeError, so fall back to an empty mapping.
        sam_config = {} if self.sam_config is None else self.sam_config

        if sam_config.get('p_losses_sam', False):
            if sam_mask is None:
                raise ValueError("sam_mask is required when sam_config['p_losses_sam'] is enabled")
            mask = F.interpolate(sam_mask, noise.shape[2:], mode='bilinear')
            if sam_config.get('mask_coefficient', False):
                # NOTE(review): ``self.mask_coefficient`` is not set in this class;
                # presumably registered elsewhere -- confirm before enabling this option.
                mask = mask * extract(self.mask_coefficient.to(mask.device), t, x_start.shape)
            # Out-of-place add so a caller-supplied ``noise`` tensor is left untouched.
            noise = noise + mask

        x_tp1_gt = self.q_sample(x_start=x_start, t=t, noise=noise)
        # NOTE(review): for t == 0 this passes t - 1 == -1; presumably ``q_sample``
        # handles negative timesteps -- confirm against the parent class.
        x_t_gt = self.q_sample(x_start=x_start, t=t - 1, noise=noise)
        noise_pred = self.denoise_fn(x_tp1_gt, t, cond, img_lr_up, sam_mask=sam_mask)
        # p_sample is decorated with torch.no_grad(), so these two are computed
        # without gradient tracking.
        x_t_pred, x0_pred = self.p_sample(x_tp1_gt, t, cond, img_lr_up, noise_pred=noise_pred, sam_mask=sam_mask)

        if self.loss_type == 'l1':
            loss = (noise - noise_pred).abs().mean()
        elif self.loss_type == 'l2':
            loss = F.mse_loss(noise, noise_pred)
        elif self.loss_type == 'ssim':
            loss = (noise - noise_pred).abs().mean()
            loss = loss + (1 - self.ssim_loss(noise, noise_pred))
        else:
            raise NotImplementedError(f'unsupported loss_type: {self.loss_type!r}')
        return loss, x_tp1_gt, noise_pred, x_t_pred, x_t_gt, x0_pred

    @torch.no_grad()
    def p_sample(self, x, t, cond, img_lr_up, noise_pred=None, clip_denoised=True, repeat_noise=False, sam_mask=None):
        """Take one reverse-diffusion step from ``x`` at timesteps ``t``.

        Args:
            x: Current noisy tensor.
            t: Per-sample timestep tensor; entries equal to 0 receive no added noise.
            cond: Conditioning input forwarded to ``denoise_fn``.
            img_lr_up: Forwarded to ``denoise_fn``.
            noise_pred: Pre-computed noise prediction. When ``None`` it is obtained
                by calling ``denoise_fn``.
            clip_denoised: Forwarded to ``p_mean_variance``.
            repeat_noise: Forwarded to ``noise_like``.
            sam_mask: Forwarded to ``denoise_fn`` when ``noise_pred`` is ``None``.

        Returns:
            Tuple ``(x_prev, x0_pred)``: the sampled tensor for the previous step
            and the ``x0_pred`` value returned by ``p_mean_variance``.
        """
        if noise_pred is None:
            noise_pred = self.denoise_fn(x, t, cond=cond, img_lr_up=img_lr_up, sam_mask=sam_mask)
        batch_size = x.shape[0]
        model_mean, _, model_log_variance, x0_pred = self.p_mean_variance(
            x=x, t=t, noise_pred=noise_pred, clip_denoised=clip_denoised)

        noise = noise_like(x.shape, x.device, repeat_noise)

        # no noise when t == 0; shape (b, 1, ..., 1) so it broadcasts over x
        nonzero_mask = (1 - (t == 0).float()).reshape(batch_size, *((1,) * (x.dim() - 1)))
        return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, x0_pred

    @torch.no_grad()
    def sample(self, img_lr, img_lr_up, shape, sam_mask=None, save_intermediate=False):
        """Run the full reverse process from the last timestep down to 0.

        Args:
            img_lr: Low-resolution input; fed to ``self.rrdb`` when
                ``hparams['use_rrdb']`` is set, otherwise used directly as ``cond``.
            img_lr_up: Passed to ``p_sample`` and ``res2img``; also the start point
                for ``q_sample`` when ``hparams['res']`` is falsy.
            shape: Shape of the tensor to sample; ``shape[0]`` is the batch size.
            sam_mask: Forwarded to :meth:`p_sample` at every step.
            save_intermediate: When true, collect the ``res2img`` outputs of every
                step (moved to CPU) and return them as a third element.

        Returns:
            ``(img, rrdb_out)``, or ``(img, rrdb_out, images)`` when
            ``save_intermediate`` is true. ``images`` is a list of
            ``(img, x_recon)`` CPU tensor pairs, one per timestep.
        """
        device = self.betas.device
        b = shape[0]

        if not hparams['res']:
            # Start from img_lr_up diffused to the final timestep.
            t = torch.full((b,), self.num_timesteps - 1, device=device, dtype=torch.long)
            img = self.q_sample(img_lr_up, t, noise=None)
        else:
            img = torch.randn(shape, device=device)

        if hparams['use_rrdb']:
            rrdb_out, cond = self.rrdb(img_lr, True)
        else:
            rrdb_out = img_lr_up
            cond = img_lr

        it = reversed(range(0, self.num_timesteps))

        if self.sample_tqdm:
            it = tqdm(it, desc='sampling loop time step', total=self.num_timesteps)

        images = []
        for i in it:
            img, x_recon = self.p_sample(
                img, torch.full((b,), i, device=device, dtype=torch.long), cond, img_lr_up, sam_mask=sam_mask)
            if save_intermediate:
                img_ = self.res2img(img, img_lr_up)
                x_recon_ = self.res2img(x_recon, img_lr_up)
                images.append((img_.cpu(), x_recon_.cpu()))
        img = self.res2img(img, img_lr_up)

        if save_intermediate:
            return img, rrdb_out, images
        return img, rrdb_out