import argparse
import json
import sys
from pathlib import Path

import k_diffusion
import numpy as np
import torch
import torch.nn as nn
from einops import rearrange, repeat
from omegaconf import OmegaConf
from PIL import Image
from pytorch_lightning import seed_everything
from tqdm import tqdm

sys.path.append("./")
sys.path.append("./stable_diffusion")

from ldm.modules.attention import CrossAttention, MemoryEfficientCrossAttention
from ldm.util import instantiate_from_config
from metrics.clip_similarity import ClipSimilarity


################################################################################
# Modified K-diffusion Euler ancestral sampler with prompt-to-prompt.
# https://github.com/crowsonkb/k-diffusion/blob/master/k_diffusion/sampling.py


def append_dims(x, target_dims):
    """Pad ``x`` with trailing singleton axes until it has ``target_dims`` dimensions.

    Raises:
        ValueError: if ``x`` already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Indexing with None inserts a new length-1 axis at that position.
        return x[(...,) + (None,) * missing]
    raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")


def to_d(x, sigma, denoised):
    """Return the Karras ODE derivative ``(x - denoised) / sigma``.

    ``sigma`` is broadcast against ``x`` by appending trailing unit axes.
    """
    sigma_b = append_dims(sigma, x.ndim)
    residual = x - denoised
    return residual / sigma_b


def get_ancestral_step(sigma_from, sigma_to):
    """Split one ancestral step from ``sigma_from`` to ``sigma_to``.

    Returns ``(sigma_down, sigma_up)``: the deterministic step targets
    ``sigma_down`` and fresh noise of scale ``sigma_up`` is added afterwards,
    with ``sigma_down**2 + sigma_up**2 == sigma_to**2``. ``sigma_up`` is capped
    at ``sigma_to``.
    """
    from_sq = sigma_from**2
    to_sq = sigma_to**2
    up_candidate = (to_sq * (from_sq - to_sq) / from_sq) ** 0.5
    sigma_up = min(sigma_to, up_candidate)
    sigma_down = (to_sq - sigma_up**2) ** 0.5
    return sigma_down, sigma_up


def sample_euler_ancestral(model, x, sigmas, prompt2prompt_threshold=0.0, **extra_args):
    """Ancestral sampling with Euler method steps and a prompt-to-prompt toggle.

    Args:
        model: denoiser called as ``model(x, sigma_batch, **extra_args)`` that
            returns the denoised estimate of ``x``. Its submodules are scanned
            for ``CrossAttention`` / ``MemoryEfficientCrossAttention`` layers.
        x: initial noisy latents (the caller scales them by ``sigmas[0]``).
        sigmas: noise schedule; ``len(sigmas) - 1`` steps are taken, step ``i``
            going from ``sigmas[i]`` to ``sigmas[i + 1]``.
        prompt2prompt_threshold: fraction of the trajectory, counted from the
            first step, during which every attention layer gets
            ``prompt_to_prompt = True``; for the remaining steps it is ``False``.
            What the flag does is implemented in ``ldm.modules.attention``
            (per this script's ``--min-p2p`` help: fixing self-attention maps).
        **extra_args: forwarded unchanged to ``model`` on every call.

    Returns:
        The latents after the final step.
    """
    # The module tree does not change while sampling, so collect the layers once.
    attention_layers = [
        m for m in model.modules() if isinstance(m, (CrossAttention, MemoryEfficientCrossAttention))
    ]
    n_steps = len(sigmas) - 1
    # Normalise the step index to [0, 1]. With a single step (len(sigmas) == 2)
    # the original denominator was 0 and raised ZeroDivisionError.
    progress_denom = max(n_steps - 1, 1)
    s_in = x.new_ones([x.shape[0]])
    for i in range(n_steps):
        prompt_to_prompt = prompt2prompt_threshold > i / progress_denom
        for m in attention_layers:
            m.prompt_to_prompt = prompt_to_prompt
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        if sigmas[i + 1] > 0:
            # Make noise the same across all samples in batch (x[:1] broadcasts).
            x = x + torch.randn_like(x[:1]) * sigma_up
    return x


################################################################################


def load_model_from_config(config, ckpt, vae_ckpt=None, verbose=False):
    """Instantiate ``config.model`` and load its weights from checkpoint ``ckpt``.

    If ``vae_ckpt`` is given, each ``first_stage_model.*`` entry of the main
    state dict is replaced by the prefix-stripped key from the VAE checkpoint's
    ``state_dict`` (a key absent there raises ``KeyError``). Weights are loaded
    with ``strict=False``; missing/unexpected keys are printed only when
    ``verbose`` is true.

    Note: ``torch.load`` unpickles the file, so only load trusted checkpoints.
    """
    print(f"Loading model from {ckpt}")
    checkpoint = torch.load(ckpt, map_location="cpu")
    if "global_step" in checkpoint:
        print(f"Global Step: {checkpoint['global_step']}")
    state_dict = checkpoint["state_dict"]

    if vae_ckpt is not None:
        print(f"Loading VAE from {vae_ckpt}")
        vae_state = torch.load(vae_ckpt, map_location="cpu")["state_dict"]
        prefix = "first_stage_model."
        state_dict = {
            key: (vae_state[key[len(prefix) :]] if key.startswith(prefix) else value)
            for key, value in state_dict.items()
        }

    model = instantiate_from_config(config.model)
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if verbose:
        if len(missing) > 0:
            print("missing keys:")
            print(missing)
        if len(unexpected) > 0:
            print("unexpected keys:")
            print(unexpected)
    return model


class CFGDenoiser(nn.Module):
    """Classifier-free guidance wrapper around a denoiser.

    Evaluates the unconditional and conditional branches in one batched call
    and combines them as ``uncond + (cond - uncond) * cfg_scale``.
    """

    def __init__(self, model):
        super().__init__()
        self.inner_model = model

    def forward(self, x, sigma, uncond, cond, cfg_scale):
        """Return the guided output for latents ``x`` at noise levels ``sigma``.

        The batch is doubled: the first copy of ``x`` is paired with ``uncond``
        and the second with ``cond``; the inner model's output is split back in
        half in the same order.
        """
        batched_x = torch.cat((x, x))
        batched_sigma = torch.cat((sigma, sigma))
        batched_cond = torch.cat((uncond, cond))
        out = self.inner_model(batched_x, batched_sigma, cond=batched_cond)
        out_uncond, out_cond = out.chunk(2)
        return out_uncond + (out_cond - out_uncond) * cfg_scale


def to_pil(image: torch.Tensor) -> Image.Image:
    """Convert a ``(c, h, w)`` tensor with values in ``[0, 1]`` to a PIL image.

    Values are clipped to ``[0, 1]`` and rounded to the nearest 8-bit level.
    Without the clip, out-of-range floats cast to ``uint8`` produce undefined,
    platform-dependent bytes instead of saturating. Plain truncation also biased
    every pixel downward by up to one level.
    """
    # detach() so a tensor that still requires grad can be converted to numpy.
    array = rearrange(image.detach().cpu().numpy(), "c h w -> h w c")
    array = np.clip(array, 0.0, 1.0) * 255.0
    return Image.fromarray(np.round(array).astype(np.uint8))


def _parse_args():
    """Parse and validate the command-line options for dataset generation."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--out_dir",
        type=str,
        required=True,
        help="Path to output dataset directory.",
    )
    parser.add_argument(
        "--prompts_file",
        type=str,
        required=True,
        help="Path to prompts .jsonl file.",
    )
    parser.add_argument(
        "--ckpt",
        type=str,
        default="stable_diffusion/models/ldm/stable-diffusion-v1/v1-5-pruned-emaonly.ckpt",
        help="Path to stable diffusion checkpoint.",
    )
    parser.add_argument(
        "--vae-ckpt",
        type=str,
        default="stable_diffusion/models/ldm/stable-diffusion-v1/vae-ft-mse-840000-ema-pruned.ckpt",
        help="Path to vae checkpoint.",
    )
    parser.add_argument(
        "--steps",
        type=int,
        default=100,
        help="Number of sampling steps.",
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=100,
        help="Number of samples to generate per prompt (before CLIP filtering).",
    )
    parser.add_argument(
        "--max-out-samples",
        type=int,
        default=4,
        help="Max number of output samples to save per prompt (after CLIP filtering).",
    )
    parser.add_argument(
        "--n-partitions",
        type=int,
        default=1,
        help="Number of total partitions.",
    )
    parser.add_argument(
        "--partition",
        type=int,
        default=0,
        help="Partition index.",
    )
    parser.add_argument(
        "--min-p2p",
        type=float,
        default=0.1,
        help="Min prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--max-p2p",
        type=float,
        default=0.9,
        help="Max prompt2prompt threshold (portion of denoising for which to fix self attention maps).",
    )
    parser.add_argument(
        "--min-cfg",
        type=float,
        default=7.5,
        help="Min classifier free guidance scale.",
    )
    parser.add_argument(
        "--max-cfg",
        type=float,
        default=15,
        help="Max classifier free guidance scale.",
    )
    parser.add_argument(
        "--clip-threshold",
        type=float,
        default=0.2,
        help="CLIP threshold for text-image similarity of each image.",
    )
    parser.add_argument(
        "--clip-dir-threshold",
        type=float,
        default=0.2,
        help="Directional CLIP threshold for similarity of change between pairs of text and pairs of images.",
    )
    parser.add_argument(
        "--clip-img-threshold",
        type=float,
        default=0.7,
        help="CLIP threshold for image-image similarity.",
    )
    opt = parser.parse_args()
    # Fail with a clear CLI error instead of a numpy error (n_partitions < 1) or
    # silent negative-index wraparound (partition < 0).
    if opt.n_partitions < 1:
        parser.error(f"--n-partitions must be at least 1, got {opt.n_partitions}.")
    if not 0 <= opt.partition < opt.n_partitions:
        parser.error(f"--partition must be in [0, {opt.n_partitions - 1}], got {opt.partition}.")
    return opt


def _save_best_samples(prompt_dir, results, opt):
    """Save the top CLIP-filtered image pairs of one prompt and append their metadata.

    ``results`` maps seed -> result dict as built in ``main``. A pair is kept only
    if it meets every CLIP threshold in ``opt``. Kept pairs are ranked by
    directional CLIP similarity (ties broken by larger seed), and at most
    ``opt.max_out_samples`` are written. Saved entries have their images popped
    from ``results``.
    """
    passing = [
        (result["clip_sim_dir"], seed)
        for seed, result in results.items()
        if result["clip_sim_image"] >= opt.clip_img_threshold
        and result["clip_sim_dir"] >= opt.clip_dir_threshold
        and result["clip_sim_0"] >= opt.clip_threshold
        and result["clip_sim_1"] >= opt.clip_threshold
    ]
    passing.sort(reverse=True)
    best = passing[: opt.max_out_samples]
    if not best:
        # Like before, no metadata file is created when nothing survives the filter.
        return
    with open(prompt_dir.joinpath("metadata.jsonl"), "a") as fp:
        for _, seed in best:
            result = results[seed]
            image_0 = result.pop("image_0")
            image_1 = result.pop("image_1")
            image_0.save(prompt_dir.joinpath(f"{seed}_0.jpg"), quality=100)
            image_1.save(prompt_dir.joinpath(f"{seed}_1.jpg"), quality=100)
            fp.write(f"{json.dumps(dict(seed=seed, **result))}\n")


def main():
    """Generate a paired-image dataset with prompt-to-prompt and CLIP filtering.

    Each line of ``--prompts_file`` is a JSON object; its ``"input"`` and
    ``"output"`` captions are used. For every prompt in this process's
    partition, the script:
      * samples ``--n-samples`` image pairs with random prompt-to-prompt
        thresholds and CFG scales;
      * scores each pair with ``ClipSimilarity``;
      * saves the best ``--max-out-samples`` passing pairs under
        ``<out_dir>/<global prompt index>/``, alongside ``prompt.json`` and
        ``metadata.jsonl``.
    """
    opt = _parse_args()

    global_seed = torch.randint(1 << 32, ()).item()
    print(f"Global seed: {global_seed}")
    seed_everything(global_seed)

    model = load_model_from_config(
        OmegaConf.load("stable_diffusion/configs/stable-diffusion/v1-inference.yaml"),
        ckpt=opt.ckpt,
        vae_ckpt=opt.vae_ckpt,
    )
    model.cuda().eval()
    model_wrap = k_diffusion.external.CompVisDenoiser(model)

    clip_similarity = ClipSimilarity().cuda()

    out_dir = Path(opt.out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)

    # JSON Lines is UTF-8 by definition; don't depend on the locale encoding.
    # Blank lines (e.g. a trailing empty line) would otherwise crash json.loads.
    with open(opt.prompts_file, encoding="utf-8") as fp:
        prompts = [json.loads(line) for line in fp if line.strip()]

    print(f"Partition index {opt.partition} ({opt.partition + 1} / {opt.n_partitions})")
    # Keep each prompt's global index so output directories don't collide across partitions.
    prompts = np.array_split(list(enumerate(prompts)), opt.n_partitions)[opt.partition]

    with torch.no_grad(), torch.autocast("cuda"), model.ema_scope():
        uncond = model.get_learned_conditioning(2 * [""])
        sigmas = model_wrap.get_sigmas(opt.steps)
        # Stateless wrapper: one instance serves every sample (consumes no RNG).
        model_wrap_cfg = CFGDenoiser(model_wrap)

        for i, prompt in tqdm(prompts, desc="Prompts"):
            prompt_dir = out_dir.joinpath(f"{i:07d}")
            prompt_dir.mkdir(exist_ok=True)

            with open(prompt_dir.joinpath("prompt.json"), "w") as fp:
                json.dump(prompt, fp)

            cond = model.get_learned_conditioning([prompt["input"], prompt["output"]])
            results = {}

            with tqdm(total=opt.n_samples, desc="Samples") as progress_bar:
                while len(results) < opt.n_samples:
                    seed = torch.randint(1 << 32, ()).item()
                    if seed in results:
                        continue  # never reuse a seed for the same prompt
                    torch.manual_seed(seed)

                    # One starting latent shared by both captions; only the conditioning differs.
                    x = torch.randn(1, 4, 512 // 8, 512 // 8, device="cuda") * sigmas[0]
                    x = repeat(x, "1 ... -> n ...", n=2)

                    p2p_threshold = opt.min_p2p + torch.rand(()).item() * (opt.max_p2p - opt.min_p2p)
                    cfg_scale = opt.min_cfg + torch.rand(()).item() * (opt.max_cfg - opt.min_cfg)
                    extra_args = {"cond": cond, "uncond": uncond, "cfg_scale": cfg_scale}
                    latents = sample_euler_ancestral(model_wrap_cfg, x, sigmas, p2p_threshold, **extra_args)
                    decoded = model.decode_first_stage(latents)
                    # Map decoder output from [-1, 1] to [0, 1].
                    decoded = torch.clamp((decoded + 1.0) / 2.0, min=0.0, max=1.0)

                    x0 = decoded[0]
                    x1 = decoded[1]

                    clip_sim_0, clip_sim_1, clip_sim_dir, clip_sim_image = clip_similarity(
                        x0[None], x1[None], [prompt["input"]], [prompt["output"]]
                    )

                    results[seed] = dict(
                        image_0=to_pil(x0),
                        image_1=to_pil(x1),
                        p2p_threshold=p2p_threshold,
                        cfg_scale=cfg_scale,
                        clip_sim_0=clip_sim_0[0].item(),
                        clip_sim_1=clip_sim_1[0].item(),
                        clip_sim_dir=clip_sim_dir[0].item(),
                        clip_sim_image=clip_sim_image[0].item(),
                    )

                    progress_bar.update()

            # CLIP filter to get best samples for each prompt.
            _save_best_samples(prompt_dir, results, opt)

    print("Done.")


if __name__ == "__main__":
    # Run dataset generation only when executed as a script, not when imported.
    main()