import torch
import inversion_run_base

from diffusers import (
    DDPMScheduler,
    DiffusionPipeline,
    T2IAdapter,
    MultiAdapter,
)
from controlnet_aux import (
    LineartDetector,
    CannyDetector,
)
from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img import retrieve_timesteps, retrieve_latents
from PIL import Image
from inversion_utils import get_ddpm_inversion_scheduler, create_xts
from config import get_config, get_num_steps_actual
from functools import partial
from compel import Compel, ReturnedEmbeddingsType


# Arguments object shared with inversion_run_base; read below by
# get_config(args) and inside run() (save_intermediate_results, time_measure_n).
args = inversion_run_base.args

# Module-level RNG slot. encode_image() reads this global name when sampling
# VAE latents; it starts out as None.
generator = None
device = "cuda" if torch.cuda.is_available() else "cpu"

# Line-art annotator, weights loaded from "lllyasviel/Annotators", moved to `device`.
lineart_detector = LineartDetector.from_pretrained("lllyasviel/Annotators")
lineart_detector = lineart_detector.to(device)

# Canny edge annotator (no pretrained weights are loaded here).
# NOTE(review): "canndy" is a misspelling of "canny"; run() uses this exact
# name, so rename both together if ever cleaned up.
canndy_detector = CannyDetector()

# Two SDXL T2I adapters combined into one MultiAdapter. The order here
# (line-art first, canny second) must match the order of the conditioning
# images and scales that run() passes to the pipeline.
# NOTE(review): `varient` looks like a misspelling of the diffusers
# `from_pretrained` keyword `variant`. As written it presumably does not select
# the fp16 weight files; dtype is still forced via torch_dtype and the
# `.to(torch.float16)` below. Confirm the fp16 files exist in both repos
# before switching to `variant="fp16"`.
adapters = MultiAdapter(
    [
        T2IAdapter.from_pretrained(
            "TencentARC/t2i-adapter-lineart-sdxl-1.0",
            torch_dtype=torch.float16,
            varient="fp16",
        ),
        T2IAdapter.from_pretrained(
            "TencentARC/t2i-adapter-canny-sdxl-1.0",
            torch_dtype=torch.float16,
            varient="fp16",
        ),
    ]
)
adapters = adapters.to(torch.float16)

# Build the adapter img2img pipeline from inversion_run_base.pipeline via
# `from_pipe`, attaching the adapters and loading the local custom pipeline
# implementation, then move it to `device`.
pipeline = DiffusionPipeline.from_pipe(
    inversion_run_base.pipeline,
    adapter=adapters,
    custom_pipeline="./pipelines/pipeline_sdxl_adapter_img2img.py",
)
pipeline = pipeline.to(device)

# Use a DDPMScheduler configured from the base pipeline's scheduler config.
# run() later replaces pipeline.scheduler with the result of
# get_ddpm_inversion_scheduler(...).
pipeline.scheduler = DDPMScheduler.from_config(
    inversion_run_base.pipeline.scheduler.config,
)

# Experiment configuration derived from `args`; run() mutates its fields on
# every call (num_steps_inversion, step_start, max_norm_zs, ws1, ws2).
config = get_config(args)

# Prompt encoder over SDXL's two tokenizers / text encoders. Pooled embeddings
# are requested only from the second encoder (requires_pooled=[False, True]).
compel_proc = Compel(
  tokenizer=[pipeline.tokenizer, pipeline.tokenizer_2] ,
  text_encoder=[pipeline.text_encoder, pipeline.text_encoder_2],
  returned_embeddings_type=ReturnedEmbeddingsType.PENULTIMATE_HIDDEN_STATES_NON_NORMALIZED,
  requires_pooled=[False, True]
)

def run(
    input_image: Image.Image,
    src_prompt: str,
    tgt_prompt: str,
    generate_size: int,
    seed: int,
    w1: float,
    w2: float,
    num_steps: int,
    start_step: int,
    guidance_scale: float,
    lineart_scale: float = 0.5,
    canny_scale: float = 0.5,
    lineart_detect: float = 0.375,
    canny_detect: float = 0.375,
):
    """Edit ``input_image`` from ``src_prompt`` towards ``tgt_prompt``.

    The image is encoded to VAE latents and noised along the selected
    timesteps with ``create_xts``. The pipeline's scheduler is replaced by the
    DDPM inversion scheduler, and the adapter pipeline is run on a batch of
    three prompts ``[src, src, tgt]`` conditioned on line-art and canny maps
    of the input.

    Args:
        input_image: Source image. It is VAE-encoded and fed to both edge
            detectors.
        src_prompt: Prompt describing the source image (batch entries 0 and 1).
        tgt_prompt: Edit prompt (batch entry 2, whose output is returned).
        generate_size: Output resolution for both detector maps
            (``image_resolution``).
        seed: Seed for the ``torch.Generator``. It drives VAE latent sampling,
            noise creation and the pipeline call.
        w1: Stored in every entry of ``config.ws1`` (one per actual step).
            Interpreted by the inversion scheduler; confirm the meaning in
            inversion_utils.
        w2: Same as ``w1``, for ``config.ws2``.
        num_steps: Total inversion steps (``config.num_steps_inversion``).
        start_step: Requested start step (``config.step_start``). It is
            adjusted below if the selected timestep slice has a different
            length than ``get_num_steps_actual`` reported.
        guidance_scale: Classifier-free guidance scale passed to the pipeline.
        lineart_scale: Conditioning scale for the line-art adapter.
        canny_scale: Conditioning scale for the canny adapter.
        lineart_detect: Fraction of ``generate_size`` used as the line-art
            detector's ``detect_resolution``.
        canny_detect: Fraction of ``generate_size`` used as the canny
            detector's ``detect_resolution``.

    Returns:
        Element 2 of the pipeline output's ``.images``, i.e. the result for
        ``tgt_prompt``.

    Side effects:
        Mutates the module-level ``config``, ``generator``,
        ``pipeline.scheduler`` and ``pipeline.__call__``.
    """
    # Bind the module-level name: encode_image() reads the global `generator`.
    # A plain local assignment left the global at None, so the VAE latent
    # sampling was not controlled by `seed`.
    global generator
    generator = torch.Generator().manual_seed(seed)

    config.num_steps_inversion = num_steps
    config.step_start = start_step
    num_steps_actual = get_num_steps_actual(config)

    num_steps_inversion = config.num_steps_inversion
    # Fraction of the schedule skipped before denoising begins.
    denoising_start = (num_steps_inversion - num_steps_actual) / num_steps_inversion

    timesteps, num_inference_steps = retrieve_timesteps(
        pipeline.scheduler, num_steps_inversion, device, None
    )
    timesteps, num_inference_steps = pipeline.get_timesteps(
        num_inference_steps=num_inference_steps,
        denoising_start=denoising_start,
        strength=0,
        device=device,
    )
    timesteps = timesteps.type(torch.int64)
    timesteps = [torch.tensor(t) for t in timesteps.tolist()]

    # Keep the config consistent with the number of timesteps actually
    # selected: shift step_start by the difference and use the real count.
    timesteps_len = len(timesteps)
    config.step_start = start_step + num_steps_actual - timesteps_len
    num_steps_actual = timesteps_len
    # Per-step norm limits for the inversion scheduler. Presumably -1 means
    # "no limit" and only the last step is capped; confirm in inversion_utils.
    config.max_norm_zs = [-1] * (num_steps_actual - 1) + [15.5]

    lineart_image = lineart_detector(input_image, detect_resolution=int(generate_size * lineart_detect), image_resolution=generate_size)
    canny_image = canndy_detector(input_image, detect_resolution=int(generate_size * canny_detect), image_resolution=generate_size)
    # Order must match the adapter order in the module-level MultiAdapter.
    cond_image = [lineart_image, canny_image]
    conditioning_scale = [lineart_scale, canny_scale]

    # Pre-bind the per-call pipeline arguments. This rebinds the instance
    # attribute on the shared pipeline each call; partial keywords supplied
    # later override earlier ones, so repeated calls use the latest values.
    pipeline.__call__ = partial(
        pipeline.__call__,
        num_inference_steps=num_steps_inversion,
        guidance_scale=guidance_scale,
        generator=generator,
        denoising_start=denoising_start,
        strength=0,
        adapter_image=cond_image,
        adapter_conditioning_scale=conditioning_scale,
    )

    x_0 = encode_image(input_image, pipeline)
    x_ts = create_xts(1, None, 0, generator, pipeline.scheduler, timesteps, x_0, no_add_noise=False)
    x_ts = [xt.to(dtype=torch.float16) for xt in x_ts]
    latents = [x_ts[0]]
    x_ts_c_hat = [None]
    config.ws1 = [w1] * num_steps_actual
    config.ws2 = [w2] * num_steps_actual

    # The six empty lists are fresh per-call buffers handed to the scheduler
    # factory; they are not read again in this function.
    pipeline.scheduler = get_ddpm_inversion_scheduler(
        pipeline.scheduler,
        config.step_function,
        config,
        timesteps,
        config.save_timesteps,
        latents,
        x_ts,
        x_ts_c_hat,
        args.save_intermediate_results,
        pipeline,
        x_0,
        [],
        [],
        [],
        [],
        [],
        [],
        "res12",
        image_name="im_name",
        time_measure_n=args.time_measure_n,
    )

    # Batch of three: two source-prompt copies plus the target prompt.
    latent = latents[0].expand(3, -1, -1, -1)
    prompt = [src_prompt, src_prompt, tgt_prompt]
    conditioning, pooled = compel_proc(prompt)

    images = pipeline.__call__(
        image=latent,
        prompt_embeds=conditioning,
        pooled_prompt_embeds=pooled,
        eta=1,
    ).images
    return images[2]

def encode_image(image, pipe):
    """Encode ``image`` into scaled VAE latents for ``pipe``.

    The image is preprocessed with ``pipe.image_processor``, moved to the
    module-level ``device`` in ``pipe.dtype``, and encoded with ``pipe.vae``.
    Latents are obtained via ``retrieve_latents`` using the module-level
    ``generator``, which may be a single generator, a per-sample list, or None.

    Args:
        image: Image(s) accepted by ``pipe.image_processor.preprocess``.
        pipe: Pipeline providing ``image_processor``, ``vae`` and ``dtype``.

    Returns:
        Latents multiplied by ``pipe.vae.config.scaling_factor``, cast to
        ``torch.float16``.
    """
    image = pipe.image_processor.preprocess(image)
    origin_dtype = pipe.dtype
    image = image.to(device=device, dtype=origin_dtype)

    upcast = pipe.vae.config.force_upcast
    if upcast:
        # The VAE asks to run in float32; upcast both the input and the VAE
        # for this call only.
        image = image.float()
        pipe.vae.to(dtype=torch.float32)

    try:
        if isinstance(generator, list):
            # One generator per sample: encode each image in the batch
            # separately (previously only the first image was encoded).
            per_sample = [
                retrieve_latents(pipe.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            init_latents = torch.cat(per_sample, dim=0)
        else:
            init_latents = retrieve_latents(pipe.vae.encode(image), generator=generator)
    finally:
        # Restore the shared VAE's dtype even if encoding fails, so an error
        # does not leave the pipeline's VAE stuck in float32.
        if upcast:
            pipe.vae.to(origin_dtype)

    init_latents = init_latents.to(origin_dtype)
    init_latents = pipe.vae.config.scaling_factor * init_latents

    return init_latents.to(dtype=torch.float16)

def get_timesteps(pipe, num_inference_steps, strength, device, denoising_start=None):
    """Select the slice of ``pipe.scheduler.timesteps`` to denoise over.

    Without ``denoising_start``, the last ``strength`` fraction of the
    schedule is kept. With ``denoising_start``, every timestep below the
    discrete cutoff ``num_train_timesteps * (1 - denoising_start)`` is kept,
    and ``strength`` is ignored.

    Args:
        pipe: Object exposing ``scheduler`` with ``timesteps``, ``order`` and
            ``config.num_train_timesteps``.
        num_inference_steps: Number of steps the scheduler was set up for.
        strength: Fraction of the schedule to run. Only used when
            ``denoising_start`` is None.
        device: Unused. Kept for signature compatibility with the diffusers
            pipeline method of the same name.
        denoising_start: Optional fraction of the schedule already completed.

    Returns:
        ``(timesteps, num_steps)``: the selected timesteps and how many
        denoising steps they represent.
    """
    scheduler = pipe.scheduler

    if denoising_start is None:
        init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
        t_start = max(num_inference_steps - init_timestep, 0)
        timesteps = scheduler.timesteps[t_start * scheduler.order :]
        return timesteps, num_inference_steps - t_start

    # Strength is irrelevant if we directly request a timestep to start at;
    # that is, strength is determined by the denoising_start instead.
    num_train = scheduler.config.num_train_timesteps
    discrete_timestep_cutoff = int(round(num_train - (denoising_start * num_train)))

    timesteps = scheduler.timesteps
    num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
    if scheduler.order == 2 and num_inference_steps % 2 == 0:
        # For a 2nd-order scheduler every timestep except the highest is
        # duplicated. An even count would cut between the 1st and 2nd
        # derivative of a step, so add one to end after a full 2nd-order step.
        num_inference_steps = num_inference_steps + 1

    # Keep the trailing `num_inference_steps` entries (the schedule is assumed
    # ordered high-to-low, as the original slice-from-the-end did). Use an
    # explicit start index because `timesteps[-0:]` would return *every*
    # timestep when nothing falls below the cutoff.
    t_start = max(len(timesteps) - num_inference_steps, 0)
    return timesteps[t_start:], num_inference_steps