# Everything that can improve a v-prediction model:
# dynamic scaling + tsnr + beta modifier + dynamic cfg rescale + ...
# written by lvmin at stanford 2024

import torch
import numpy as np

from tqdm import tqdm
from functools import partial
from diffusers_vdm.basics import extract_into_tensor


# Build a float32 torch tensor from array-like data (used below to turn the
# float64 numpy schedules into buffers); extra torch.tensor kwargs pass through.
to_torch = partial(torch.tensor, dtype=torch.float32)


def rescale_zero_terminal_snr(betas):
    """Rescale a beta schedule so the final timestep has zero signal-to-noise ratio.

    sqrt(alphas_cumprod) is shifted so that its last entry becomes exactly 0,
    then scaled so that its first entry keeps its original value, and the
    result is converted back into per-step betas.

    Args:
        betas: numpy array of per-step betas (1-D, one entry per timestep).

    Returns:
        A new numpy array of betas whose last entry is 1.0; the input array
        is left untouched.
    """
    sqrt_abar = np.sqrt(np.cumprod(1.0 - betas, axis=0))
    first, last = sqrt_abar[0], sqrt_abar[-1]

    # Shift (last -> 0), then stretch so the first value is restored.
    sqrt_abar = (sqrt_abar - last) * (first / (first - last))

    # Undo the sqrt, then undo the cumulative product.
    abar = sqrt_abar ** 2
    step_alphas = np.concatenate([abar[0:1], abar[1:] / abar[:-1]])
    return 1 - step_alphas


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """Blend a guided prediction with a copy rescaled to the positive branch's std.

    The per-sample std is taken over every dimension except the first. A
    copy of ``noise_cfg`` is rescaled to the std of ``noise_pred_text``, which
    counteracts the over-exposure that large guidance scales cause.

    Args:
        noise_cfg: guided prediction tensor.
        noise_pred_text: conditional (positive) prediction tensor.
        guidance_rescale: mix factor; 0.0 returns ``noise_cfg`` unchanged, and
            1.0 returns the fully rescaled copy.

    Returns:
        ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.
    """
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    std_ratio = (noise_pred_text.std(dim=text_dims, keepdim=True)
                 / noise_cfg.std(dim=cfg_dims, keepdim=True))
    rescaled = noise_cfg * std_ratio

    # A partial mix avoids the "plain looking" results of a full rescale.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg


class SamplerDynamicTSNR(torch.nn.Module):
    """Zero-terminal-SNR diffusion helper with dynamic x0 rescaling and a DDIM-style sampler.

    The noise schedule is linear in sqrt(beta) between 0.00085 and 0.012 over
    1000 timesteps. It is passed through ``rescale_zero_terminal_snr``, so
    ``alphas_cumprod[-1]`` is exactly zero. On top of that, ``scale_arr`` holds a
    per-timestep multiplier for x0. The multiplier ramps linearly from 1.0 to
    ``terminal_scale`` over the first ``turning_step`` timesteps and then stays
    at ``terminal_scale``. ``get_ground_truth`` applies it to training targets,
    and ``forward`` compensates for it while sampling.
    """

    @torch.no_grad()
    def __init__(self, unet, terminal_scale=0.7, guidance_rescale=0.7, turning_step=400):
        """Build the schedule buffers on ``unet.device``.

        Args:
            unet: denoiser exposing ``.device`` and ``.dtype``, called as
                ``unet(x, t, **kwargs)`` (see ``model_apply``).
            terminal_scale: x0 multiplier reached at timestep
                ``turning_step - 1`` and used for every later timestep.
            guidance_rescale: blend factor passed to ``rescale_noise_cfg`` in
                ``model_apply``. 0 disables the rescale; 1 fully matches the
                positive branch's std.
            turning_step: length of the linear 1.0 -> ``terminal_scale`` ramp.
                Expected to lie in ``[0, 1000]``.
        """
        super().__init__()
        self.unet = unet

        self.is_v = True  # the unet output is interpreted as a v-prediction
        self.n_timestep = 1000
        self.guidance_rescale = guidance_rescale

        linear_start = 0.00085
        linear_end = 0.012

        # Linear in sqrt(beta), then squared; afterwards forced to zero terminal SNR.
        betas = np.linspace(linear_start ** 0.5, linear_end ** 0.5, self.n_timestep, dtype=np.float64) ** 2
        betas = rescale_zero_terminal_snr(betas)
        alphas = 1. - betas

        alphas_cumprod = np.cumprod(alphas, axis=0)

        self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod).to(unet.device))
        self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)).to(unet.device))
        self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)).to(unet.device))

        # Dynamic TSNR: x0 multiplier per timestep (ramp, then constant).
        scale_arr = np.concatenate([
            np.linspace(1.0, terminal_scale, turning_step),
            np.full(self.n_timestep - turning_step, terminal_scale)
        ])
        self.register_buffer('scale_arr', to_torch(scale_arr).to(unet.device))

    def predict_eps_from_z_and_v(self, x_t, t, v):
        """Recover eps from x_t and v, where v = sqrt(a_t) * eps - sqrt(1 - a_t) * x0.

        ``t`` is an integer timestep; it indexes the schedule buffers directly.
        """
        return self.sqrt_alphas_cumprod[t] * v + self.sqrt_one_minus_alphas_cumprod[t] * x_t

    def predict_start_from_z_and_v(self, x_t, t, v):
        """Recover x0 from x_t and v; ``t`` is an integer timestep."""
        return self.sqrt_alphas_cumprod[t] * x_t - self.sqrt_one_minus_alphas_cumprod[t] * v

    def q_sample(self, x0, t, noise):
        """Forward-diffuse: sqrt(a_t) * x0 + sqrt(1 - a_t) * noise.

        NOTE(review): ``extract_into_tensor`` presumably gathers the per-sample
        buffer values at ``t`` and reshapes them to broadcast against ``x0`` —
        confirm in diffusers_vdm.basics.
        """
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * x0 +
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * noise)

    def get_v(self, x0, t, noise):
        """Return the v target: sqrt(a_t) * noise - sqrt(1 - a_t) * x0."""
        return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x0.shape) * noise -
                extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x0.shape) * x0)

    def dynamic_x0_rescale(self, x0, t):
        """Multiply x0 by the dynamic-TSNR scale ``scale_arr[t]``."""
        return x0 * extract_into_tensor(self.scale_arr, t, x0.shape)

    @torch.no_grad()
    def get_ground_truth(self, x0, noise, t):
        """Training helper: return ``(x_t, target)`` for clean latents ``x0``.

        ``x0`` is first multiplied by ``scale_arr[t]``. The target is v when
        ``is_v`` is set, otherwise the noise.
        """
        x0 = self.dynamic_x0_rescale(x0, t)
        xt = self.q_sample(x0, t, noise)
        target = self.get_v(x0, t, noise) if self.is_v else noise
        return xt, target

    def get_uniform_trailing_steps(self, steps):
        """Return ``steps`` evenly spaced timesteps (ascending) that end at ``n_timestep - 1``.

        For example, ``steps=25`` gives ``[39, 79, ..., 999]``.

        Raises:
            ValueError: if ``steps < 1``.
        """
        if steps < 1:
            raise ValueError(f"steps must be >= 1, got {steps}")
        c = self.n_timestep / steps
        # np.arange with a fractional step has length ceil(n_timestep / c), and
        # float rounding can push that to steps + 1 (e.g. steps=61 for 1000
        # timesteps). The extra near-zero entry would round to 0 and become
        # timestep -1, so keep exactly the first `steps` values.
        grid = np.arange(self.n_timestep, 0, -c)[:steps]
        ddim_timesteps = np.flip(np.round(grid)).astype(np.int64)
        steps_out = ddim_timesteps - 1
        return torch.tensor(steps_out, device=self.unet.device, dtype=torch.long)

    @torch.no_grad()
    def forward(self, latent_shape, steps, extra_args, progress_tqdm=None, eta=1.0):
        """Sample a latent from Gaussian noise with ``steps`` DDIM-style steps.

        Args:
            latent_shape: shape of the initial noise tensor (batch first).
            steps: number of denoising steps (>= 1).
            extra_args: dict forwarded to ``model_apply``. It must contain
                ``'cfg_scale'``, ``'positive'`` and ``'negative'``; the latter
                two are keyword-argument dicts for the unet.
            progress_tqdm: optional tqdm-like wrapper for the step iterator
                (defaults to ``tqdm``).
            eta: multiplier on the per-step noise std. 1.0 is the default and
                the previous hard-coded value; 0.0 makes every step after the
                initial noise deterministic.

        Returns:
            The latent after the last step, which targets the noise level of
            timestep 0.
        """
        bar = tqdm if progress_tqdm is None else progress_tqdm

        timesteps = self.get_uniform_trailing_steps(steps)
        # The step at timesteps[k] moves to timesteps[k - 1]; the final step targets t=0.
        timesteps_prev = torch.nn.functional.pad(timesteps[:-1], pad=(1, 0))

        x = torch.randn(latent_shape, device=self.unet.device, dtype=self.unet.dtype)

        alphas = self.alphas_cumprod[timesteps]
        alphas_prev = self.alphas_cumprod[timesteps_prev]
        scale_arr = self.scale_arr[timesteps]
        scale_arr_prev = self.scale_arr[timesteps_prev]

        sqrt_one_minus_alphas = torch.sqrt(1 - alphas)
        # Per-step noise std, computed entirely in torch (no numpy round trip).
        sigmas = eta * torch.sqrt((1 - alphas_prev) / (1 - alphas) * (1 - alphas / alphas_prev))

        s_in = x.new_ones((x.shape[0]))
        s_x = x.new_ones((x.shape[0], ) + (1, ) * (x.ndim - 1))
        n_steps = len(timesteps)
        for i in bar(range(n_steps)):
            index = n_steps - 1 - i  # walk from the noisiest timestep down
            t = timesteps[index].item()

            model_output = self.model_apply(x, t * s_in, **extra_args)

            if self.is_v:
                e_t = self.predict_eps_from_z_and_v(x, t, model_output)
                pred_x0 = self.predict_start_from_z_and_v(x, t, model_output)
            else:
                e_t = model_output
                a_t = alphas[index].item() * s_x
                sqrt_one_minus_at = sqrt_one_minus_alphas[index].item() * s_x
                pred_x0 = (x - sqrt_one_minus_at * e_t) / a_t.sqrt()

            a_prev = alphas_prev[index].item() * s_x
            sigma_t = sigmas[index].item() * s_x

            # Dynamic rescale: training scaled x0 by scale_arr[t] (get_ground_truth),
            # so move the prediction to the scale expected at the previous timestep.
            scale_t = scale_arr[index].item() * s_x
            prev_scale_t = scale_arr_prev[index].item() * s_x
            pred_x0 = pred_x0 * (prev_scale_t / scale_t)

            # With alphas_cumprod[t] == 0 (the first step under zero terminal SNR)
            # and eta=1, this term is exactly 0 in theory. Rounding, especially in
            # fp16, can make it slightly negative, and sqrt would then give NaN.
            dir_xt = (1. - a_prev - sigma_t ** 2).clamp(min=0).sqrt() * e_t
            noise = sigma_t * torch.randn_like(x)
            x = a_prev.sqrt() * pred_x0 + dir_xt + noise

        return x

    @torch.no_grad()
    def model_apply(self, x, t, **extra_args):
        """Run the unet with classifier-free guidance and guidance rescale.

        The unet is called once with ``extra_args['positive']`` and once with
        ``extra_args['negative']``. The two results are combined as
        ``n + cfg_scale * (p - n)`` and then blended by ``rescale_noise_cfg``
        with ``self.guidance_rescale``.
        """
        x = x.to(device=self.unet.device, dtype=self.unet.dtype)
        cfg_scale = extra_args['cfg_scale']
        p = self.unet(x, t, **extra_args['positive'])
        n = self.unet(x, t, **extra_args['negative'])
        o = n + cfg_scale * (p - n)
        o_better = rescale_noise_cfg(o, p, guidance_rescale=self.guidance_rescale)
        return o_better