"""
GaussianDiffusion wraps operators for denoising diffusion models, including the
diffusion and denoising processes, as well as the loss evaluation.
"""
import torch
import torchsde
import random
from tqdm.auto import trange


__all__ = ['GaussianDiffusion']


def _i(tensor, t, x):
    """
    Index tensor using t and format the output according to x.
    """
    shape = (x.size(0), ) + (1, ) * (x.ndim - 1)
    return tensor[t.to(tensor.device)].view(shape).to(x.device)


class BatchedBrownianTree:
    """
    A wrapper around torchsde.BrownianTree that enables batches of entropy.

    Args:
        x (Tensor): Reference tensor; when no ``w0`` is supplied the initial
            Brownian value defaults to ``torch.zeros_like(x)``.
        t0: One end of the time interval. The two ends may be given in
            either order; they are sorted internally and the orientation is
            remembered in ``self.sign``.
        t1: The other end of the time interval.
        seed (int or sequence of int, optional): Entropy for the tree(s). A
            sized object must have one entry per batch item of ``x`` and
            yields one tree per entry; anything without ``len()`` is treated
            as a single seed shared by the whole batch. When omitted, a
            random integer is drawn with ``torch.randint``.
        **kwargs: Extra keyword arguments forwarded to
            ``torchsde.BrownianTree``. ``w0`` may be passed here to override
            the default initial value.

    Raises:
        ValueError: If a sequence of seeds is given whose length differs
            from ``x.shape[0]``.
    """
    def __init__(self, x, t0, t1, seed=None, **kwargs):
        t0, t1, self.sign = self.sort(t0, t1)
        # w0 is handed to BrownianTree positionally below, so it has to be
        # removed from kwargs; leaving it there would pass it twice and
        # raise "got multiple values for argument 'w0'".
        w0 = kwargs.pop('w0') if 'w0' in kwargs else torch.zeros_like(x)
        if seed is None:
            seed = torch.randint(0, 2 ** 63 - 1, []).item()

        # Decide between batched and single-seed mode. The len() probe is
        # kept outside of any assert so the detection still works when
        # Python runs with assertions disabled (-O).
        try:
            num_seeds = len(seed)
        except TypeError:
            seed = [seed]
            self.batched = False
        else:
            if num_seeds != x.shape[0]:
                raise ValueError(
                    f'expected {x.shape[0]} seeds (one per batch item), '
                    f'got {num_seeds}'
                )
            self.batched = True
            # each tree models a single batch item
            w0 = w0[0]

        self.trees = [torchsde.BrownianTree(
            t0, w0, t1, entropy=s, **kwargs
        ) for s in seed]

    @staticmethod
    def sort(a, b):
        """Return ``(low, high, sign)``; sign is 1 if ``a < b`` else -1."""
        return (a, b, 1) if a < b else (b, a, -1)

    def __call__(self, t0, t1):
        """
        Return the Brownian increment between ``t0`` and ``t1``.

        The endpoints are sorted before querying the trees; the result is
        multiplied by the product of the construction-time sign and the
        query-time sign. In batched mode the per-item increments are stacked
        along a new leading dimension; otherwise the single tree's increment
        is returned.
        """
        t0, t1, sign = self.sort(t0, t1)
        w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
        return w if self.batched else w[0]


class BrownianTreeNoiseSampler:
    """
    Noise source that draws its samples from a BatchedBrownianTree.

    Args:
        x (Tensor): Reference tensor handed to the underlying tree (it
            determines the shape, device and dtype of the samples).
        sigma_min (float): Lower end of the valid sigma interval.
        sigma_max (float): Upper end of the valid sigma interval.
        seed (int or List[int]): Random seed. Passing a list of seeds makes
            the sampler use one BrownianTree per batch item, each with its
            own seed.
        transform (callable): Maps a sigma value to the sampler's internal
            timestep. Defaults to the identity.
    """
    def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x):
        self.transform = transform
        t_low, t_high = (
            self.transform(torch.as_tensor(bound))
            for bound in (sigma_min, sigma_max)
        )
        self.tree = BatchedBrownianTree(x, t_low, t_high, seed)

    def __call__(self, sigma, sigma_next):
        """
        Return the Brownian increment between the two sigmas, divided by the
        square root of the absolute internal-time difference.
        """
        t_from = self.transform(torch.as_tensor(sigma))
        t_to = self.transform(torch.as_tensor(sigma_next))
        increment = self.tree(t_from, t_to)
        scale = (t_to - t_from).abs().sqrt()
        return increment / scale


def get_scalings(sigma):
    """
    Return the ``(c_out, c_in)`` scaling pair for a noise level ``sigma``.

    ``c_out`` is ``-sigma`` and ``c_in`` is ``1 / sqrt(sigma ** 2 + 1)``.
    """
    sigma_data = 1.
    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
    return -sigma, c_in


@torch.no_grad()
def sample_dpmpp_2m_sde(
    noise,
    model,
    sigmas,
    eta=1.,
    s_noise=1.,
    solver_type='midpoint',
    show_progress=True
):
    """
    DPM-Solver++ (2M) SDE.

    Args:
        noise (Tensor): Initial noise; the starting state is
            ``noise * sigmas[0]``.
        model (callable): Called as ``model(x_in, sigma)`` and expected to
            return the denoised estimate. For finite sigmas ``x_in`` is the
            current state multiplied by ``c_in`` from ``get_scalings``; for
            an infinite sigma the raw ``noise`` is passed.
        sigmas (Tensor): 1-D schedule of noise levels. An entry equal to
            ``inf`` triggers a single Euler step; a next level of 0 triggers
            the final denoising step.
        eta (float): Multiplier on the log-sigma step that controls the
            stochastic part of the update.
        s_noise (float): Extra multiplier on the injected noise.
        solver_type (str): Second-order correction, 'heun' or 'midpoint'.
        show_progress (bool): Whether to display the tqdm progress bar.

    Returns:
        Tensor: The state after the last step.
    """
    assert solver_type in {'heun', 'midpoint'}

    x = noise * sigmas[0]
    sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas[sigmas < float('inf')].max()
    noise_sampler = BrownianTreeNoiseSampler(x, sigma_min, sigma_max)
    old_denoised = None
    h_last = None

    for i in trange(len(sigmas) - 1, disable=not show_progress):
        if sigmas[i] == float('inf'):
            # Euler method
            denoised = model(noise, sigmas[i])
            x = denoised + sigmas[i + 1] * noise
        else:
            _, c_in = get_scalings(sigmas[i])
            denoised = model(x * c_in, sigmas[i])
            if sigmas[i + 1] == 0:
                # Denoising step
                x = denoised
            else:
                # DPM-Solver++(2M) SDE
                t, s = -sigmas[i].log(), -sigmas[i + 1].log()
                h = s - t
                eta_h = eta * h
                # shared factor 1 - exp(-h - eta_h)
                phi = (-h - eta_h).expm1().neg()

                x = sigmas[i + 1] / sigmas[i] * (-eta_h).exp() * x + \
                    phi * denoised

                # The multistep correction needs both the previous denoised
                # estimate and the previous step size.
                if old_denoised is not None and h_last is not None:
                    r = h_last / h
                    if solver_type == 'heun':
                        x = x + (phi / (-h - eta_h) + 1) * \
                            (1 / r) * (denoised - old_denoised)
                    elif solver_type == 'midpoint':
                        x = x + 0.5 * phi * \
                            (1 / r) * (denoised - old_denoised)

                x = x + noise_sampler(
                    sigmas[i],
                    sigmas[i + 1]
                ) * sigmas[i + 1] * (-2 * eta_h).expm1().neg().sqrt() * s_noise

                # Only record h when it was actually computed: on the final
                # denoising step h is never assigned, and reading it there
                # raised UnboundLocalError whenever that step was the first
                # finite-sigma step of the schedule.
                h_last = h

            old_denoised = denoised
    return x


class GaussianDiffusion(object):

    def __init__(self, sigmas, prediction_type='eps'):
        assert prediction_type in {'x0', 'eps', 'v'}
        self.sigmas = sigmas.float()                        # noise coefficients
        self.alphas = torch.sqrt(1 - sigmas ** 2).float()   # signal coefficients
        self.num_timesteps = len(sigmas)
        self.prediction_type = prediction_type

    def diffuse(self, x0, t, noise=None):
        """
        Forward-diffuse a clean signal to timestep t by sampling from
        q(x_t | x_0) = N(x_t | alpha_t x_0, sigma_t^2 I).

        When ``noise`` is not supplied, standard Gaussian noise shaped like
        ``x0`` is drawn.
        """
        if noise is None:
            noise = torch.randn_like(x0)
        signal_part = _i(self.alphas, t, x0) * x0
        noise_part = _i(self.sigmas, t, x0) * noise
        return signal_part + noise_part
    
    def denoise(
        self,
        xt,
        t,
        s,
        model,
        model_kwargs={},
        guide_scale=None,
        guide_rescale=None,
        clamp=None,
        percentile=None
    ):
        r"""
        Apply one step of denoising from the posterior distribution q(x_s | x_t, x0).
        Since x0 is not available, estimate the denoising results using the learned
        distribution p(x_s | x_t, \hat{x}_0 == f(x_t)).

        Args:
            xt (Tensor): Noisy sample at timestep t, batch on dim 0.
            t (Tensor): Integer timestep indices into the schedule, one per
                batch item.
            s (Tensor or None): Target timestep indices. ``None`` means
                ``t - 1``. Negative entries are treated as the noise-free
                state (alpha set to 1).
            model (callable): Called as ``model(xt, t=t, **kwargs)``.
            model_kwargs (dict or list): A dict when ``guide_scale`` is None;
                otherwise a two-element list of dicts holding the
                conditional and the non-conditional kwargs, in that order.
                The default dict is only read, never mutated.
            guide_scale (float, optional): Classifier-free guidance scale.
            guide_rescale (float, optional): Rescale factor in [0, 1] applied
                to the guided output (arXiv:2305.08891).
            clamp (float, optional): Clamp the predicted x0 to
                ``[-clamp, clamp]``. Ignored when ``percentile`` is given.
            percentile (float, optional): Quantile in (0, 1] used for
                dynamic thresholding of the predicted x0.

        Returns:
            tuple: ``(mu, var, log_var, x0, eps)`` where ``mu``/``var``/
            ``log_var`` describe the posterior, ``x0`` is the (restricted)
            clean-signal estimate and ``eps`` is the noise recomputed from it.
        """
        s = t - 1 if s is None else s

        # hyperparams
        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s ** 2)

        # precompute variables
        betas = 1 - (alphas / alphas_s) ** 2
        coef1 = betas * alphas_s / sigmas ** 2
        coef2 = (alphas * sigmas_s ** 2) / (alphas_s * sigmas ** 2)
        var = betas * (sigmas_s / sigmas) ** 2
        log_var = torch.log(var).clamp_(-20, 20)

        # prediction
        if guide_scale is None:
            assert isinstance(model_kwargs, dict)
            out = model(xt, t=t, **model_kwargs)
        else:
            # classifier-free guidance (arXiv:2207.12598)
            # model_kwargs[0]: conditional kwargs
            # model_kwargs[1]: non-conditional kwargs
            assert isinstance(model_kwargs, list) and len(model_kwargs) == 2
            y_out = model(xt, t=t, **model_kwargs[0])
            if guide_scale == 1.:
                # a scale of 1 reduces to the conditional output, so the
                # unconditional forward pass is skipped
                out = y_out
            else:
                u_out = model(xt, t=t, **model_kwargs[1])
                out = u_out + guide_scale * (y_out - u_out)

                # rescale the output according to arXiv:2305.08891
                if guide_rescale is not None:
                    assert guide_rescale >= 0 and guide_rescale <= 1
                    ratio = (y_out.flatten(1).std(dim=1) / (
                        out.flatten(1).std(dim=1) + 1e-12
                    )).view((-1, ) + (1, ) * (y_out.ndim - 1))
                    out *= guide_rescale * ratio + (1 - guide_rescale) * 1.0

        # compute x0
        if self.prediction_type == 'x0':
            x0 = out
        elif self.prediction_type == 'eps':
            x0 = (xt - sigmas * out) / alphas
        elif self.prediction_type == 'v':
            x0 = alphas * xt - sigmas * out
        else:
            raise NotImplementedError(
                f'prediction_type {self.prediction_type} not implemented'
            )

        # restrict the range of x0
        if percentile is not None:
            # NOTE: percentile should only be used when data is within range [-1, 1]
            assert percentile > 0 and percentile <= 1
            # per-sample threshold, never below 1; kept in its own name so
            # the timestep argument ``s`` is not overwritten
            bound = torch.quantile(x0.flatten(1).abs(), percentile, dim=1)
            bound = bound.clamp_(1.0).view((-1, ) + (1, ) * (xt.ndim - 1))
            x0 = torch.min(bound, torch.max(-bound, x0)) / bound
        elif clamp is not None:
            x0 = x0.clamp(-clamp, clamp)

        # recompute eps using the restricted x0
        eps = (xt - alphas * x0) / sigmas

        # compute mu (mean of posterior distribution) using the restricted x0
        mu = coef1 * x0 + coef2 * xt
        return mu, var, log_var, x0, eps

    @torch.no_grad()
    def sample(
        self,
        noise,
        model,
        model_kwargs={},
        condition_fn=None,
        guide_scale=None,
        guide_rescale=None,
        clamp=None,
        percentile=None,
        solver='euler_a',
        steps=20,
        t_max=None,
        t_min=None,
        discretization=None,
        discard_penultimate_step=None,
        return_intermediate=None,
        show_progress=False,
        seed=-1,
        **kwargs
    ):
        """
        Generate samples from ``noise`` with a sigma-space solver.

        Args:
            noise (Tensor): Initial noise handed to the solver; timesteps
                are created on its device.
            model (callable): Network forwarded to ``self.denoise``.
            model_kwargs (dict or list): Forwarded to ``self.denoise``.
            condition_fn: Accepted but not used by this method.
            guide_scale: Forwarded to ``self.denoise``.
            guide_rescale: Forwarded to ``self.denoise``.
            clamp: Forwarded to ``self.denoise``.
            percentile: Forwarded to ``self.denoise``.
            solver (str): Solver name. Only 'dpmpp_2m_sde' is registered in
                this method, so the default 'euler_a' is rejected.
            steps (int or LongTensor): Number of steps, or an explicit
                tensor of timesteps.
            t_max (int, optional): Largest timestep; defaults to
                ``num_timesteps - 1``.
            t_min (int, optional): Smallest timestep; defaults to 0.
            discretization (str, optional): 'leading', 'linspace' or
                'trailing'; ``None`` selects 'linspace'.
            discard_penultimate_step (bool, optional): Whether to drop the
                second-to-last sigma. ``None`` picks a per-solver default;
                it is forced to False when ``steps`` is a tensor.
            return_intermediate (str, optional): 'x0' or 'xt' to also return
                the per-call intermediates collected from the model.
            show_progress (bool): Forwarded to the solver.
            seed (int): A negative value draws a random seed.
                NOTE(review): the resulting seed is not forwarded to the
                solver anywhere in this method.
            **kwargs: Extra keyword arguments forwarded to the solver.

        Returns:
            Tensor, or ``(Tensor, list)`` when ``return_intermediate`` is set.

        Raises:
            KeyError: If ``solver`` is not a registered solver name.
        """
        # sanity check
        assert isinstance(steps, (int, torch.LongTensor))
        assert t_max is None or (t_max > 0 and t_max <= self.num_timesteps - 1)
        assert t_min is None or (t_min >= 0 and t_min < self.num_timesteps - 1)
        assert discretization in (None, 'leading', 'linspace', 'trailing')
        assert discard_penultimate_step in (None, True, False)
        assert return_intermediate in (None, 'x0', 'xt')

        # function of diffusion solver
        solver_fns = {
            # 'heun': sample_heun,
            'dpmpp_2m_sde': sample_dpmpp_2m_sde
        }
        if solver not in solver_fns:
            # same exception type as the plain lookup, but it names the
            # supported solvers instead of only echoing the bad key
            raise KeyError(
                f'solver {solver!r} is not implemented; '
                f'available solvers: {sorted(solver_fns)}'
            )
        solver_fn = solver_fns[solver]

        # options
        schedule = 'karras' if 'karras' in solver else None
        discretization = discretization or 'linspace'
        # kept so the global RNG is consumed exactly as before; see the
        # NOTE(review) on ``seed`` in the docstring
        seed = seed if seed >= 0 else random.randint(0, 2 ** 31)
        if isinstance(steps, torch.LongTensor):
            discard_penultimate_step = False
        if discard_penultimate_step is None:
            discard_penultimate_step = solver in (
                'dpm2',
                'dpm2_ancestral',
                'dpmpp_2m_sde',
                'dpm2_karras',
                'dpm2_ancestral_karras',
                'dpmpp_2m_sde_karras'
            )

        # function for denoising xt to get x0
        intermediates = []
        def model_fn(xt, sigma):
            # denoising
            t = self._sigma_to_t(sigma).repeat(len(xt)).round().long()
            x0 = self.denoise(
                xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
                percentile
            )[-2]

            # collect intermediate outputs
            if return_intermediate == 'xt':
                intermediates.append(xt)
            elif return_intermediate == 'x0':
                intermediates.append(x0)
            return x0

        # get timesteps
        if isinstance(steps, int):
            # one extra step is requested because the penultimate sigma is
            # removed again further down
            steps += 1 if discard_penultimate_step else 0
            t_max = self.num_timesteps - 1 if t_max is None else t_max
            t_min = 0 if t_min is None else t_min

            # discretize timesteps
            if discretization == 'leading':
                steps = torch.arange(
                    t_min, t_max + 1, (t_max - t_min + 1) / steps
                ).flip(0)
            elif discretization == 'linspace':
                steps = torch.linspace(t_max, t_min, steps)
            elif discretization == 'trailing':
                steps = torch.arange(t_max, t_min - 1, -((t_max - t_min + 1) / steps))
            else:
                raise NotImplementedError(
                    f'{discretization} discretization not implemented'
                )
            steps = steps.clamp_(t_min, t_max)
        steps = torch.as_tensor(steps, dtype=torch.float32, device=noise.device)

        # get sigmas
        sigmas = self._t_to_sigma(steps)
        sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if schedule == 'karras':
            # NOTE(review): karras_schedule is not defined in the part of
            # the file visible here; confirm it is available before
            # registering a '*_karras' solver.
            if sigmas[0] == float('inf'):
                sigmas = karras_schedule(
                    n=len(steps) - 1,
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas[sigmas < float('inf')].max().item(),
                    rho=7.
                ).to(sigmas)
                sigmas = torch.cat([
                    sigmas.new_tensor([float('inf')]), sigmas, sigmas.new_zeros([1])
                ])
            else:
                sigmas = karras_schedule(
                    n=len(steps),
                    sigma_min=sigmas[sigmas > 0].min().item(),
                    sigma_max=sigmas.max().item(),
                    rho=7.
                ).to(sigmas)
                sigmas = torch.cat([sigmas, sigmas.new_zeros([1])])
        if discard_penultimate_step:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        # sampling
        x0 = solver_fn(
            noise,
            model_fn,
            sigmas,
            show_progress=show_progress,
            **kwargs
        )
        return (x0, intermediates) if return_intermediate is not None else x0
    
    @torch.no_grad()
    def ddim_reverse_sample(
        self,
        xt,
        t,
        model,
        model_kwargs={},
        clamp=None,
        percentile=None,
        guide_scale=None,
        guide_rescale=None,
        ddim_timesteps=20,
        reverse_steps=600
        ):
        r"""Sample from p(x_{t+1} | x_t) using DDIM reverse ODE (deterministic).

        Args:
            xt (Tensor): Current sample at timestep t.
            t (Tensor): Integer timestep indices, one per batch item.
            model (callable): Network forwarded to ``self.denoise``.
            model_kwargs (dict or list): Forwarded to ``self.denoise``.
            clamp: Forwarded to ``self.denoise``.
            percentile: Forwarded to ``self.denoise``.
            guide_scale: Forwarded to ``self.denoise``.
            guide_rescale: Forwarded to ``self.denoise``.
            ddim_timesteps (int): Number of reverse steps; the stride is
                ``reverse_steps // ddim_timesteps``.
            reverse_steps (int): Exclusive upper bound on the target
                timestep.

        Returns:
            tuple: ``(mu, x0)`` - the sample moved to timestep ``t + stride``
            (clamped to ``[0, reverse_steps - 1]``) and the x0 estimate.
        """
        stride = reverse_steps // ddim_timesteps

        # predict distribution of p(x_{t-1} | x_t)
        _, _, _, x0, eps = self.denoise(
                xt, t, None, model, model_kwargs, guide_scale, guide_rescale, clamp,
                percentile
            )
        # derive variables
        s = (t + stride).clamp(0, reverse_steps-1)
        # hyperparams: only the coefficients at the target timestep are
        # needed; the coefficients at t were looked up but never used
        alphas_s = _i(self.alphas, s.clamp(0), xt)
        alphas_s[s < 0] = 1.
        sigmas_s = torch.sqrt(1 - alphas_s ** 2)

        # reverse sample
        mu = alphas_s * x0 + sigmas_s * eps
        return mu, x0
    
    @torch.no_grad()
    def ddim_reverse_sample_loop(
        self,
        x0,
        model,
        model_kwargs={},
        clamp=None,
        percentile=None,
        guide_scale=None,
        guide_rescale=None,
        ddim_timesteps=20,
        reverse_steps=600
        ):
        """
        Run ``ddim_reverse_sample`` repeatedly, starting from ``x0`` at
        timestep 0 and advancing by ``reverse_steps // ddim_timesteps`` per
        iteration while the timestep stays below ``reverse_steps``.

        Returns the sample produced by the last iteration.
        """
        batch_size = x0.size(0)
        stride = reverse_steps // ddim_timesteps
        latent = x0

        # reconstruction steps
        for step in torch.arange(0, reverse_steps, stride):
            t = torch.full(
                (batch_size, ), step, dtype=torch.long, device=latent.device
            )
            latent, _ = self.ddim_reverse_sample(
                latent,
                t,
                model,
                model_kwargs=model_kwargs,
                clamp=clamp,
                percentile=percentile,
                guide_scale=guide_scale,
                guide_rescale=guide_rescale,
                ddim_timesteps=ddim_timesteps,
                reverse_steps=reverse_steps
            )
        return latent
    
    def _sigma_to_t(self, sigma):
        if sigma == float('inf'):
            t = torch.full_like(sigma, len(self.sigmas) - 1)
        else:
            log_sigmas = torch.sqrt(
                self.sigmas ** 2 / (1 - self.sigmas ** 2)
            ).log().to(sigma)
            log_sigma = sigma.log()
            dists = log_sigma - log_sigmas[:, None]
            low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(
                max=log_sigmas.shape[0] - 2
            )
            high_idx = low_idx + 1
            low, high = log_sigmas[low_idx], log_sigmas[high_idx]
            w = (low - log_sigma) / (low - high)
            w = w.clamp(0, 1)
            t = (1 - w) * low_idx + w * high_idx
            t = t.view(sigma.shape)
        if t.ndim == 0:
            t = t.unsqueeze(0)
        return t

    def _t_to_sigma(self, t):
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigmas = torch.sqrt(self.sigmas ** 2 / (1 - self.sigmas ** 2)).log().to(t)
        log_sigma = (1 - w) * log_sigmas[low_idx] + w * log_sigmas[high_idx]
        log_sigma[torch.isnan(log_sigma) | torch.isinf(log_sigma)] = float('inf')
        return log_sigma.exp()
    
    def prev_step(self, model_out, t, xt, inference_steps=50):
        """
        Move ``xt`` from timestep ``t`` to the earlier timestep
        ``t - num_timesteps // inference_steps``.

        x0 is reconstructed as ``alphas * xt - sigmas * model_out``, eps is
        recomputed from it, and both are recombined with the coefficients of
        the earlier timestep. A negative target timestep uses alpha = 1.
        """
        stride = self.num_timesteps // inference_steps
        prev_t = t - stride

        sigma_t = _i(self.sigmas, t, xt)
        alpha_t = _i(self.alphas, t, xt)
        alpha_prev = _i(self.alphas, prev_t.clamp(0), xt)
        alpha_prev[prev_t < 0] = 1.
        sigma_prev = (1 - alpha_prev ** 2).sqrt()

        x0 = alpha_t * xt - sigma_t * model_out
        eps = (xt - alpha_t * x0) / sigma_t
        return alpha_prev * x0 + sigma_prev * eps
    
    def next_step(self, model_out, t, xt, inference_steps=50):
        """
        Move ``xt`` from timestep ``t - num_timesteps // inference_steps``
        to timestep ``t``.

        Args:
            model_out (Tensor): Model output; x0 is reconstructed from it as
                ``alphas * xt - sigmas * model_out``.
            t (Tensor): Integer target timestep indices, one per batch item.
            xt (Tensor): Sample at the source timestep.
            inference_steps (int): Number of inference steps; the stride is
                ``num_timesteps // inference_steps``.

        Returns:
            Tensor: The sample at timestep ``t``.
        """
        stride = self.num_timesteps // inference_steps
        next_t = t
        # Clamp with tensor ops rather than the builtin min(): min() cannot
        # compare a batched tensor against an int, the old upper bound was a
        # hard-coded 999, and a negative source timestep used to wrap
        # around to the end of the schedule when used as an index.
        t = (t - stride).clamp(0, self.num_timesteps - 1)

        sigmas = _i(self.sigmas, t, xt)
        alphas = _i(self.alphas, t, xt)
        alphas_next = _i(self.alphas, next_t.clamp(0), xt)
        alphas_next[next_t < 0] = 1.
        sigmas_next = torch.sqrt(1 - alphas_next ** 2)

        x0 = alphas * xt - sigmas * model_out
        eps = (xt - alphas * x0) / sigmas
        next_sample = alphas_next * x0 + sigmas_next * eps
        return next_sample
    
    def get_noise_pred_single(self, xt, t, model, model_kwargs):
        """
        Run a single forward pass ``model(xt, t=t, **model_kwargs)`` and
        return its output unchanged. ``model_kwargs`` must be a dict.
        """
        assert isinstance(model_kwargs, dict)
        return model(xt, t=t, **model_kwargs)