from transformers import AutoTokenizer
from PIL import Image
import cv2
import torch
from omegaconf import OmegaConf
import math
from copy import deepcopy
import torch.nn.functional as F
import numpy as np
import clip
from transformers import AutoTokenizer

from kandinsky2.model.text_encoders import TextEncoder
from kandinsky2.vqgan.autoencoder import VQModelInterface, AutoencoderKL, MOVQ
from kandinsky2.model.samplers import DDIMSampler, PLMSSampler
from kandinsky2.model.model_creation import create_model, create_gaussian_diffusion
from kandinsky2.model.prior import PriorDiffusionModel, CustomizedTokenizer
from kandinsky2.utils import prepare_image, q_sample, process_images, prepare_mask


class Kandinsky2_1:
    """Kandinsky 2.1 text-to-image, image-mixing, img2img and inpainting pipeline.

    The pipeline chains three stages, all built in ``__init__``:

    1. ``self.prior`` (fed by ``self.clip_model``'s text tower) turns a prompt
       into a CLIP image embedding (``generate_clip_emb``);
    2. ``self.model`` is the diffusion decoder, conditioned on that embedding
       and on ``self.text_encoder`` output (``generate_img``);
    3. ``self.image_encoder`` (when configured) decodes sampled latents to
       pixels and encodes init images into latents.
    """

    def __init__(
        self,
        config,
        model_path,
        prior_path,
        device,
        task_type="text2img"
    ):
        """Build every sub-model from ``config`` and load the checkpoints.

        Args:
            config: mapping holding model / prior / encoder / diffusion
                settings. NOTE: ``config["model_config"]`` is mutated in place
                (``up``, ``inpainting`` and ``cache_text_emb`` are set).
            model_path: path to the diffusion decoder checkpoint.
            prior_path: path to the prior checkpoint.
            device: torch device every sub-model is moved to.
            task_type: ``"text2img"`` or ``"inpainting"``.

        Raises:
            ValueError: on an unknown ``task_type`` or image encoder name.
        """
        self.config = config
        self.device = device
        self.use_fp16 = self.config["model_config"]["use_fp16"]
        self.task_type = task_type
        self.clip_image_size = config["clip_image_size"]
        if task_type == "text2img":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = False
        elif task_type == "inpainting":
            self.config["model_config"]["up"] = False
            self.config["model_config"]["inpainting"] = True
        else:
            raise ValueError("Only text2img and inpainting is available")

        self.tokenizer1 = AutoTokenizer.from_pretrained(self.config["tokenizer_name"])
        self.tokenizer2 = CustomizedTokenizer()
        clip_mean, clip_std = torch.load(
            config["prior"]["clip_mean_std_path"], map_location="cpu"
        )

        self.prior = PriorDiffusionModel(
            config["prior"]["params"],
            self.tokenizer2,
            clip_mean,
            clip_std,
        )
        self.prior.load_state_dict(torch.load(prior_path, map_location='cpu'), strict=False)
        if self.use_fp16:
            self.prior = self.prior.half()
        self.text_encoder = TextEncoder(**self.config["text_enc_params"])
        if self.use_fp16:
            self.text_encoder = self.text_encoder.half()

        self.clip_model, self.preprocess = clip.load(
            config["clip_name"], device=self.device, jit=False
        )
        self.clip_model.eval()

        # The latent image encoder/decoder is optional; keep the attribute
        # defined (as None) so later code can test for it instead of crashing.
        image_enc_params = self.config["image_enc_params"]
        self.use_image_enc = image_enc_params is not None
        self.image_encoder = None
        if self.use_image_enc:
            self.scale = image_enc_params["scale"]
            encoder_name = image_enc_params["name"]
            if encoder_name == "AutoencoderKL":
                self.image_encoder = AutoencoderKL(**image_enc_params["params"])
            elif encoder_name == "VQModelInterface":
                self.image_encoder = VQModelInterface(**image_enc_params["params"])
            elif encoder_name == "MOVQ":
                self.image_encoder = MOVQ(**image_enc_params["params"])
                self.image_encoder.load_state_dict(
                    torch.load(image_enc_params["ckpt_path"], map_location='cpu')
                )
            else:
                raise ValueError(f"Unknown image encoder: {encoder_name}")
            self.image_encoder.eval()

        self.config["model_config"]["cache_text_emb"] = True
        self.model = create_model(**self.config["model_config"])
        self.model.load_state_dict(torch.load(model_path, map_location='cpu'))
        if self.use_fp16:
            self.model.convert_to_fp16()
            self.model_dtype = torch.float16
        else:
            self.model_dtype = torch.float32

        if self.image_encoder is not None:
            if self.use_fp16:
                self.image_encoder = self.image_encoder.half()
            self.image_encoder = self.image_encoder.to(self.device).eval()
        self.text_encoder = self.text_encoder.to(self.device).eval()
        self.prior = self.prior.to(self.device).eval()
        self.model.eval()
        self.model.to(self.device)

    def get_new_h_w(self, h, w):
        """Return the latent (height, width) for an ``h`` x ``w`` image.

        Each side is rounded up to a multiple of 64 pixels and then divided
        by 8, i.e. ``ceil(side / 64) * 8``.
        """
        new_h = h // 64
        if h % 64 != 0:
            new_h += 1
        new_w = w // 64
        if w % 64 != 0:
            new_w += 1
        return new_h * 8, new_w * 8

    @torch.no_grad()
    def encode_text(self, text_encoder, tokenizer, prompt, batch_size):
        """Encode ``batch_size`` copies of ``prompt`` followed by as many empty prompts.

        The empty prompts form the unconditional half used for
        classifier-free guidance.

        Returns:
            ``(full_emb, pooled_emb)`` as produced by ``text_encoder``.
        """
        text_encoding = tokenizer(
            [prompt] * batch_size + [""] * batch_size,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors="pt",
        )

        tokens = text_encoding["input_ids"].to(self.device)
        mask = text_encoding["attention_mask"].to(self.device)

        full_emb, pooled_emb = text_encoder(tokens=tokens, mask=mask)
        return full_emb, pooled_emb

    @torch.no_grad()
    def generate_clip_emb(
        self,
        prompt,
        batch_size=1,
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
    ):
        """Run the prior to map ``prompt`` to CLIP image embeddings.

        The prompt is repeated ``batch_size`` times, and
        ``negative_prior_prompt`` supplies the prior's unconditional branch.
        The two are encoded with CLIP's text transformer and passed to
        ``self.prior`` with ``prior_cf_scale`` guidance and ``prior_steps``
        timestep respacing.

        Returns:
            The prior output cast to ``self.model_dtype`` (presumably one
            embedding row per prompt copy — confirm against PriorDiffusionModel).
        """
        prompts_batch = [prompt for _ in range(batch_size)]
        prior_cf_scales_batch = [prior_cf_scale] * len(prompts_batch)
        prior_cf_scales_batch = torch.tensor(prior_cf_scales_batch, device=self.device)
        max_txt_length = self.prior.model.text_ctx
        tok, mask = self.tokenizer2.padded_tokens_and_mask(
            prompts_batch, max_txt_length
        )
        cf_token, cf_mask = self.tokenizer2.padded_tokens_and_mask(
            [negative_prior_prompt], max_txt_length
        )
        if not (cf_token.shape == tok.shape):
            # Broadcast the single negative prompt across the batch.
            cf_token = cf_token.expand(tok.shape[0], -1)
            cf_mask = cf_mask.expand(tok.shape[0], -1)
        tok = torch.cat([tok, cf_token], dim=0)
        mask = torch.cat([mask, cf_mask], dim=0)
        tok, mask = tok.to(device=self.device), mask.to(device=self.device)

        x = self.clip_model.token_embedding(tok).type(self.clip_model.dtype)
        x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
        x = x.permute(1, 0, 2)  # NLD -> LND
        x = self.clip_model.transformer(x)
        x = x.permute(1, 0, 2)  # LND -> NLD
        x = self.clip_model.ln_final(x).type(self.clip_model.dtype)
        txt_feat_seq = x
        # Pooled feature is taken at the position of the largest token id in each row.
        txt_feat = (x[torch.arange(x.shape[0]), tok.argmax(dim=-1)] @ self.clip_model.text_projection)
        txt_feat, txt_feat_seq = txt_feat.float().to(self.device), txt_feat_seq.float().to(self.device)
        img_feat = self.prior(
            txt_feat,
            txt_feat_seq,
            mask,
            prior_cf_scales_batch,
            timestep_respacing=prior_steps,
        )
        return img_feat.to(self.model_dtype)

    @torch.no_grad()
    def encode_images(self, image, is_pil=False):
        """Return CLIP image embeddings for ``image`` in ``self.model_dtype``.

        When ``is_pil`` is True, ``image`` is first run through CLIP's
        ``preprocess`` and given a batch dimension of one.
        """
        if is_pil:
            image = self.preprocess(image).unsqueeze(0).to(self.device)
        return self.clip_model.encode_image(image).to(self.model_dtype)

    @torch.no_grad()
    def generate_img(
        self,
        prompt,
        img_prompt,
        batch_size=1,
        diffusion=None,
        guidance_scale=7,
        init_step=None,
        noise=None,
        init_img=None,
        img_mask=None,
        h=512,
        w=512,
        sampler="ddim_sampler",
        num_steps=50,
    ):
        """Sample ``batch_size`` images from the diffusion decoder.

        Args:
            prompt: text prompt for the decoder's text encoder.
            img_prompt: image embeddings; the callers in this class pass the
                conditional rows stacked above the negative rows
                (``2 * batch_size`` rows in total).
            batch_size: number of images to return.
            diffusion: object from ``create_gaussian_diffusion``.
            guidance_scale: classifier-free guidance weight.
            init_step: optional step to start sampling from (img2img).
            noise: optional starting latents with ``2 * batch_size`` rows.
            init_img, img_mask: encoded image and mask latents; both are
                required when ``task_type == "inpainting"``.
            h, w: output size in pixels.
            sampler: ``"p_sampler"``, ``"ddim_sampler"`` or ``"plms_sampler"``.
            num_steps: number of DDIM / PLMS steps.

        Returns:
            ``process_images`` applied to the decoded, cropped batch.

        Raises:
            ValueError: on an unknown sampler, or when the inpainting task is
                called without ``init_img`` / ``img_mask``.
        """
        new_h, new_w = self.get_new_h_w(h, w)
        full_batch_size = batch_size * 2
        model_kwargs = {}

        if init_img is not None and self.use_fp16:
            init_img = init_img.half()
        if img_mask is not None and self.use_fp16:
            img_mask = img_mask.half()
        model_kwargs["full_emb"], model_kwargs["pooled_emb"] = self.encode_text(
            text_encoder=self.text_encoder,
            tokenizer=self.tokenizer1,
            prompt=prompt,
            batch_size=batch_size,
        )
        model_kwargs["image_emb"] = img_prompt

        if self.task_type == "inpainting":
            if init_img is None or img_mask is None:
                raise ValueError("inpainting requires init_img and img_mask")
            init_img = init_img.to(self.device)
            img_mask = img_mask.to(self.device)
            model_kwargs["inpaint_image"] = init_img * img_mask
            model_kwargs["inpaint_mask"] = img_mask

        def model_fn(x_t, ts, **kwargs):
            # Classifier-free guidance: only the first half of x_t is tracked.
            # It is fed twice so the model sees it with the conditional and
            # the unconditional kwargs, and the guided eps is duplicated back.
            half = x_t[: len(x_t) // 2]
            combined = torch.cat([half, half], dim=0)
            model_out = self.model(combined, ts, **kwargs)
            # First 4 channels are eps; the remainder is presumably the learned
            # variance used by the p_sampler — confirm against create_model.
            eps, rest = model_out[:, :4], model_out[:, 4:]
            cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
            half_eps = uncond_eps + guidance_scale * (cond_eps - uncond_eps)
            eps = torch.cat([half_eps, half_eps], dim=0)
            if sampler == "p_sampler":
                return torch.cat([eps, rest], dim=1)
            return eps

        if noise is not None:
            noise = noise.float()
        if self.task_type == "inpainting":
            def denoised_fun(x_start):
                # Where img_mask is 1 the original latents are kept.
                x_start = x_start.clamp(-2, 2)
                return x_start * (1 - img_mask) + init_img * img_mask
        else:
            def denoised_fun(x):
                return x.clamp(-2, 2)

        if sampler == "p_sampler":
            self.model.del_cache()
            samples = diffusion.p_sample_loop(
                model_fn,
                (full_batch_size, 4, new_h, new_w),
                device=self.device,
                noise=noise,
                progress=True,
                model_kwargs=model_kwargs,
                init_step=init_step,
                denoised_fn=denoised_fun,
            )[:batch_size]
            self.model.del_cache()
        else:
            if sampler == "ddim_sampler":
                sampler_obj = DDIMSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            elif sampler == "plms_sampler":
                sampler_obj = PLMSSampler(
                    model=model_fn,
                    old_diffusion=diffusion,
                    schedule="linear",
                )
            else:
                raise ValueError("Only ddim_sampler and plms_sampler is available")

            self.model.del_cache()
            samples, _ = sampler_obj.sample(
                num_steps,
                full_batch_size,
                (4, new_h, new_w),
                conditioning=model_kwargs,
                x_T=noise,
                init_step=init_step,
            )
            self.model.del_cache()
            samples = samples[:batch_size]

        if self.use_image_enc:
            if self.use_fp16:
                samples = samples.half()
            samples = self.image_encoder.decode(samples / self.scale)

        samples = samples[:, :, :h, :w]
        return process_images(samples)

    @torch.no_grad()
    def create_zero_img_emb(self, batch_size):
        """Return the CLIP embedding of an all-zero image, repeated ``batch_size`` times."""
        img = torch.zeros(1, 3, self.clip_image_size, self.clip_image_size).to(self.device)
        return self.encode_images(img, is_pil=False).repeat(batch_size, 1)

    def _create_diffusion(self, sampler, num_steps):
        """Build the Gaussian diffusion; timesteps are respaced only for ``p_sampler``."""
        config = deepcopy(self.config)
        if sampler == "p_sampler":
            config["diffusion_config"]["timestep_respacing"] = str(num_steps)
        return create_gaussian_diffusion(**config["diffusion_config"])

    @torch.no_grad()
    def _negative_image_emb(
        self,
        batch_size,
        negative_decoder_prompt,
        prior_cf_scale,
        prior_steps,
        negative_prior_prompt,
    ):
        """Return the decoder's unconditional image embedding for guidance.

        An empty ``negative_decoder_prompt`` yields the zero-image embedding;
        otherwise the negative prompt is run through the prior.
        """
        if negative_decoder_prompt == "":
            return self.create_zero_img_emb(batch_size=batch_size)
        return self.generate_clip_emb(
            negative_decoder_prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )

    @torch.no_grad()
    def generate_text2img(
        self,
        prompt,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Generate ``batch_size`` images for ``prompt`` (prior + decoder)."""
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        negative_emb = self._negative_image_emb(
            batch_size, negative_decoder_prompt, prior_cf_scale, prior_steps, negative_prior_prompt
        )
        image_emb = torch.cat([image_emb, negative_emb], dim=0).to(self.device)

        diffusion = self._create_diffusion(sampler, num_steps)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

    @torch.no_grad()
    def mix_images(
        self,
        images_texts,
        weights,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Generate images from a weighted sum of CLIP embeddings.

        Args:
            images_texts: non-empty sequence of prompts (``str``, run through
                the prior) and PIL images (encoded with CLIP).
            weights: one scalar weight per entry of ``images_texts``.
            Remaining arguments are as in ``generate_text2img``.
        """
        assert len(images_texts) == len(weights) and len(images_texts) > 0

        image_emb = None
        for item, weight in zip(images_texts, weights):
            if isinstance(item, str):
                emb = self.generate_clip_emb(
                    item,
                    batch_size=1,
                    prior_cf_scale=prior_cf_scale,
                    prior_steps=prior_steps,
                    negative_prior_prompt=negative_prior_prompt,
                )
            else:
                emb = self.encode_images(item, is_pil=True)
            emb = weight * emb
            image_emb = emb if image_emb is None else image_emb + emb

        image_emb = image_emb.repeat(batch_size, 1)
        negative_emb = self._negative_image_emb(
            batch_size, negative_decoder_prompt, prior_cf_scale, prior_steps, negative_prior_prompt
        )
        image_emb = torch.cat([image_emb, negative_emb], dim=0).to(self.device)

        diffusion = self._create_diffusion(sampler, num_steps)
        return self.generate_img(
            prompt="",
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
        )

    @torch.no_grad()
    def generate_img2img(
        self,
        prompt,
        pil_img,
        strength=0.7,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Re-generate ``pil_img`` guided by ``prompt``.

        The image is encoded to latents, noised to the step implied by
        ``strength`` (larger means more noise, i.e. an earlier start step),
        and denoised from there.
        """
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        negative_emb = self._negative_image_emb(
            batch_size, negative_decoder_prompt, prior_cf_scale, prior_steps, negative_prior_prompt
        )
        image_emb = torch.cat([image_emb, negative_emb], dim=0).to(self.device)

        diffusion = self._create_diffusion(sampler, num_steps)
        diffusion_config = self.config["diffusion_config"]

        image = prepare_image(pil_img, h=h, w=w).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale

        start_step = int(diffusion.num_timesteps * (1 - strength))
        image = q_sample(
            image,
            torch.tensor(diffusion.timestep_map[start_step - 1]).to(self.device),
            schedule_name=diffusion_config["noise_schedule"],
            num_steps=diffusion_config["steps"],
        )

        # The sampler runs on 2 * batch_size rows (conditional + unconditional);
        # the single input image (assumed batch of one) is tiled to match.
        image = image.repeat(2 * batch_size, 1, 1, 1)
        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            noise=image,
            init_step=start_step,
        )

    @torch.no_grad()
    def generate_inpainting(
        self,
        prompt,
        pil_img,
        img_mask,
        num_steps=100,
        batch_size=1,
        guidance_scale=7,
        h=512,
        w=512,
        sampler="ddim_sampler",
        prior_cf_scale=4,
        prior_steps="25",
        negative_prior_prompt="",
        negative_decoder_prompt="",
    ):
        """Inpaint ``pil_img`` guided by ``prompt``.

        Args:
            img_mask: 2-D numpy array; it is resized (nearest) to the latent
                size and passed through ``prepare_mask``. In the resulting
                mask, 1 marks the region kept from ``pil_img``.
            Remaining arguments are as in ``generate_text2img``.
        """
        image_emb = self.generate_clip_emb(
            prompt,
            batch_size=batch_size,
            prior_cf_scale=prior_cf_scale,
            prior_steps=prior_steps,
            negative_prior_prompt=negative_prior_prompt,
        )
        negative_emb = self._negative_image_emb(
            batch_size, negative_decoder_prompt, prior_cf_scale, prior_steps, negative_prior_prompt
        )
        image_emb = torch.cat([image_emb, negative_emb], dim=0).to(self.device)

        diffusion = self._create_diffusion(sampler, num_steps)
        image = prepare_image(pil_img, w, h).to(self.device)
        if self.use_fp16:
            image = image.half()
        image = self.image_encoder.encode(image) * self.scale
        image_shape = tuple(image.shape[-2:])
        img_mask = torch.from_numpy(img_mask).unsqueeze(0).unsqueeze(0)
        img_mask = F.interpolate(
            img_mask,
            image_shape,
            mode="nearest",
        )
        img_mask = prepare_mask(img_mask).to(self.device)
        if self.use_fp16:
            img_mask = img_mask.half()
        # Tile the single image / mask to the sampler's 2 * batch_size rows.
        image = image.repeat(2 * batch_size, 1, 1, 1)
        img_mask = img_mask.repeat(2 * batch_size, 1, 1, 1)

        return self.generate_img(
            prompt=prompt,
            img_prompt=image_emb,
            batch_size=batch_size,
            guidance_scale=guidance_scale,
            h=h,
            w=w,
            sampler=sampler,
            num_steps=num_steps,
            diffusion=diffusion,
            init_img=image,
            img_mask=img_mask,
        )
import os
from huggingface_hub import hf_hub_url, cached_download
from copy import deepcopy
from omegaconf.dictconfig import DictConfig

def get_kandinsky2_1(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    use_flash_attention=False,
):
    """Download (if not cached) the Kandinsky 2.1 weights and build the model.

    Every file is fetched from the ``sberbank-ai/Kandinsky_2.1`` Hugging Face
    repo into ``<cache_dir>/2_1`` (text-encoder files go into a
    ``text_encoder`` sub-directory), and the config paths are pointed at them.

    Args:
        device: torch device passed to ``Kandinsky2_1``.
        task_type: ``"text2img"`` or ``"inpainting"``; selects the decoder
            checkpoint.
        cache_dir: root directory for downloaded files.
        use_auth_token: forwarded to ``cached_download``.
        use_flash_attention: stored in ``config["model_config"]``.

    Returns:
        A ready ``Kandinsky2_1`` instance.

    Raises:
        ValueError: if ``task_type`` is not supported.
    """
    checkpoint_by_task = {
        "text2img": "decoder_fp16.ckpt",
        "inpainting": "inpainting_fp16.ckpt",
    }
    # Fail fast: an unknown task would otherwise leave the checkpoint name unbound.
    if task_type not in checkpoint_by_task:
        raise ValueError("Only text2img and inpainting is available")
    model_name = checkpoint_by_task[task_type]
    prior_name = "prior_fp16.ckpt"
    repo_id = "sberbank-ai/Kandinsky_2.1"

    def download(filename, target_dir, local_name):
        """Fetch ``filename`` from the repo and store it as ``target_dir/local_name``."""
        cached_download(
            hf_hub_url(repo_id=repo_id, filename=filename),
            cache_dir=target_dir,
            force_filename=local_name,
            use_auth_token=use_auth_token,
        )

    cache_dir = os.path.join(cache_dir, "2_1")
    config = DictConfig(deepcopy(CONFIG_2_1))
    config["model_config"]["use_flash_attention"] = use_flash_attention

    download(model_name, cache_dir, model_name)
    download(prior_name, cache_dir, prior_name)

    cache_dir_text_en = os.path.join(cache_dir, "text_encoder")
    for name in (
        "config.json",
        "pytorch_model.bin",
        "sentencepiece.bpe.model",
        "special_tokens_map.json",
        "tokenizer.json",
        "tokenizer_config.json",
    ):
        download(f"text_encoder/{name}", cache_dir_text_en, name)

    download("movq_final.ckpt", cache_dir, "movq_final.ckpt")
    download("ViT-L-14_stats.th", cache_dir, "ViT-L-14_stats.th")

    config["tokenizer_name"] = cache_dir_text_en
    config["text_enc_params"]["model_path"] = cache_dir_text_en
    config["prior"]["clip_mean_std_path"] = os.path.join(cache_dir, "ViT-L-14_stats.th")
    config["image_enc_params"]["ckpt_path"] = os.path.join(cache_dir, "movq_final.ckpt")
    cache_model_name = os.path.join(cache_dir, model_name)
    cache_prior_name = os.path.join(cache_dir, prior_name)
    return Kandinsky2_1(config, cache_model_name, cache_prior_name, device, task_type=task_type)


def get_kandinsky2(
    device,
    task_type="text2img",
    cache_dir="/tmp/kandinsky2",
    use_auth_token=None,
    model_version="2.1",
    use_flash_attention=False,
):
    """Build a Kandinsky model of the requested version.

    Args:
        device: torch device for the model.
        task_type: task name forwarded to the version-specific builder.
        cache_dir: download cache root (used by 2.0 and 2.1 only).
        use_auth_token: Hugging Face token (used by 2.0 and 2.1 only).
        model_version: ``"2.0"``, ``"2.1"`` or ``"2.2"``.
        use_flash_attention: forwarded to the 2.1 builder only.

    Raises:
        ValueError: if ``model_version`` is not one of the supported versions.
    """
    if model_version == "2.0":
        return get_kandinsky2_0(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
        )
    if model_version == "2.1":
        return get_kandinsky2_1(
            device,
            task_type=task_type,
            cache_dir=cache_dir,
            use_auth_token=use_auth_token,
            use_flash_attention=use_flash_attention,
        )
    if model_version == "2.2":
        return Kandinsky2_2(device=device, task_type=task_type)
    raise ValueError("Only 2.0, 2.1 and 2.2 are available")