import io
import os
import torch
import zipfile
import spaces
import numpy as np
import gradio as gr
from PIL import Image
from tqdm.auto import tqdm
# These wildcard imports are expected to provide the shared globals used
# throughout this module (tokenizer, text_encoder, unet, vae, scheduler,
# torch_device, imageHeight, imageWidth, guidance_scale, axis_combinations).
from src.util.params import *
from src.util.clip_config import *
import matplotlib.pyplot as plt

@spaces.GPU(enable_queue=True)
def get_text_embeddings(
    prompt,
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    torch_device=torch_device,
    batch_size=1,
    negative_prompt="",
):
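    """Encode a prompt and its negative prompt with the CLIP text encoder.

    Returns the unconditional and conditional embeddings concatenated along
    the batch dimension, in the order classifier-free guidance expects.
    """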
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(torch_device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer(
        [negative_prompt] * batch_size,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
    text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

    return text_embeddings

@spaces.GPU(enable_queue=True)
def generate_latents(
    seed,
    height=imageHeight,
    width=imageWidth,
    torch_device=torch_device,
    unet=unet,
    batch_size=1,
):
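    """Sample seeded Gaussian latents at 1/8 of the target resolution (the
    VAE downsampling factor). Sampling on a CPU generator keeps the result
    reproducible across devices before it is moved to `torch_device`."""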
    generator = torch.Generator().manual_seed(int(seed))

    latents = torch.randn(
        (batch_size, unet.config.in_channels, height // 8, width // 8),
        generator=generator,
    ).to(torch_device)

    return latents

@spaces.GPU(enable_queue=True)
def generate_modified_latents(
    poke,
    seed,
    pokeX=None,
    pokeY=None,
    pokeHeight=None,
    pokeWidth=None,
    imageHeight=imageHeight,
    imageWidth=imageWidth,
):
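    """Generate base latents for `seed` and, when `poke` is set, a copy whose
    (pokeHeight x pokeWidth) latent patch centered at (pokeX, pokeY) is
    replaced with freshly sampled noise. Returns (original, modified or None).
    """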
    original_latents = generate_latents(seed, height=imageHeight, width=imageWidth)
    if poke:
        np.random.seed(int(seed))
        # generate_latents divides by 8 internally, so passing the patch size
        # multiplied by 8 yields a latent patch of exactly pokeHeight x pokeWidth.
        poke_latents = generate_latents(
            np.random.randint(0, 100000), height=pokeHeight * 8, width=pokeWidth * 8
        )

        # Top-left corner of the patch, centered on (pokeX, pokeY) in latent space.
        x_origin = pokeX - pokeWidth // 2
        y_origin = pokeY - pokeHeight // 2

        modified_latents = original_latents.clone()
        modified_latents[
            :, :, y_origin : y_origin + pokeHeight, x_origin : x_origin + pokeWidth
        ] = poke_latents
    else:
        modified_latents = None

    return original_latents, modified_latents


def convert_to_pil_image(image):
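    """Rescale a decoded VAE batch from [-1, 1] to [0, 255] and return the
    first image as a PIL.Image."""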
    # Map from the VAE's [-1, 1] output range to [0, 1], then to uint8.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
    images = (image * 255).round().astype("uint8")
    pil_images = [Image.fromarray(img) for img in images]
    return pil_images[0]

@spaces.GPU(enable_queue=True)
def generate_images(
    latents,
    text_embeddings,
    num_inference_steps,
    unet=unet,
    guidance_scale=guidance_scale,
    vae=vae,
    scheduler=scheduler,
    intermediate=False,
    progress=gr.Progress(),
):
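    """Run the denoising loop with classifier-free guidance.

    Returns a single PIL image, or, when `intermediate` is True, a list of
    (image, step label) tuples holding the decoded latents at every step.
    """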
    scheduler.set_timesteps(num_inference_steps)
    latents = latents * scheduler.init_noise_sigma
    images = []

    for i, t in enumerate(tqdm(scheduler.timesteps), start=1):
        # Duplicate the latents so the unconditional and text-conditioned
        # prompts are denoised in a single UNet forward pass.
        latent_model_input = torch.cat([latents] * 2)
        latent_model_input = scheduler.scale_model_input(latent_model_input, t)

        with torch.no_grad():
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

        # Classifier-free guidance: move the prediction away from the
        # unconditional output, toward the text-conditioned output.
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (
            noise_pred_text - noise_pred_uncond
        )

        if intermediate:
            progress(i / num_inference_steps)
            # Undo the Stable Diffusion VAE scaling factor before decoding.
            scaled_latents = 1 / 0.18215 * latents
            with torch.no_grad():
                image = vae.decode(scaled_latents).sample
            images.append((convert_to_pil_image(image), f"{i}"))

        latents = scheduler.step(noise_pred, t, latents).prev_sample

    if not intermediate:
        scaled_latents = 1 / 0.18215 * latents
        with torch.no_grad():
            image = vae.decode(scaled_latents).sample
        images = convert_to_pil_image(image)

    return images

@spaces.GPU(enable_queue=True)
def get_word_embeddings(
    prompt, tokenizer=tokenizer, text_encoder=text_encoder, torch_device=torch_device
):
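    """Encode a prompt, flatten all token embeddings into one row vector, and
    L2-normalize it."""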
    text_input = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).to(torch_device)

    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids)[0].reshape(1, -1)

    text_embeddings = text_embeddings.cpu().numpy()
    return text_embeddings / np.linalg.norm(text_embeddings)


def get_concat_embeddings(names, merge=False):
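    """Stack the word embeddings for `names`; with `merge`, average them into
    a single row vector."""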
    embeddings = []

    for name in names:
        embedding = get_word_embeddings(name)
        embeddings.append(embedding)

    embeddings = np.vstack(embeddings)

    if merge:
        embeddings = np.average(embeddings, axis=0).reshape(1, -1)

    return embeddings


def get_axis_embeddings(A, B):
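    """Build a semantic axis as the mean difference between the embeddings of
    paired prompts from `A` and `B` (the two ends of the axis)."""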
    emb = []

    for a, b in zip(A, B):
        e = get_word_embeddings(a) - get_word_embeddings(b)
        emb.append(e)

    emb = np.vstack(emb)
    ax = np.average(emb, axis=0).reshape(1, -1)

    return ax


def calculate_residual(
    axis, axis_names, from_words=None, to_words=None, residual_axis=1
):
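    """Compute the embedding component orthogonal to the two non-residual
    axes, for use as the third plot axis. `axis` is expected to stack the
    axis vectors row-wise; uncached axes fall back to `from_words + to_words`.
    """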
    axis_indices = [0, 1, 2]
    axis_indices.remove(residual_axis)

    if axis_names[axis_indices[0]] in axis_combinations:
        fembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[0]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[0]]] = from_words + to_words
        fembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    if axis_names[axis_indices[1]] in axis_combinations:
        sembeddings = get_concat_embeddings(
            axis_combinations[axis_names[axis_indices[1]]], merge=True
        )
    else:
        axis_combinations[axis_names[axis_indices[1]]] = from_words + to_words
        sembeddings = get_concat_embeddings(from_words + to_words, merge=True)

    fprojections = fembeddings @ axis[axis_indices[0]].T
    sprojections = sembeddings @ axis[axis_indices[1]].T

    # Subtract the components that lie along the two known axes, leaving the
    # residual direction: emb - (emb . u) u - (emb . v) v.
    partial_residual = fembeddings - (
        fprojections.reshape(-1, 1) * axis[axis_indices[0]]
    )
    residual = partial_residual - (
        sprojections.reshape(-1, 1) * axis[axis_indices[1]]
    )

    return residual


def calculate_step_size(num_images, start_degree_circular, end_degree_circular):
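    """Return the angular increment that spreads `num_images` frames evenly
    across the circular degree range."""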
    return (end_degree_circular - start_degree_circular) / num_images


def generate_seed_vis(seed):
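    """Render a small color strip deterministically derived from `seed`, as a
    quick visual fingerprint of the seed value."""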
    np.random.seed(int(seed))
    emb = np.random.rand(15)
    plt.close()
    plt.switch_backend("agg")
    plt.figure(figsize=(10, 0.5))
    plt.imshow([emb], cmap="viridis")
    plt.axis("off")
    return plt


def export_as_gif(images, filename, frames_per_second=2, reverse=False):
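    """Save a list of (image, label) tuples as an animated GIF under
    outputs/. With `reverse`, the intermediate frames are appended again in
    reverse order to produce a boomerang loop."""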
    imgs = [img[0] for img in images]

    if reverse:
        # Append the intermediate frames in reverse order so the GIF plays
        # forward and then backward in a continuous loop.
        imgs += imgs[2:-1][::-1]

    os.makedirs("outputs", exist_ok=True)
    imgs[0].save(
        f"outputs/{filename}",
        format="GIF",
        save_all=True,
        append_images=imgs[1:],
        duration=1000 // frames_per_second,  # milliseconds per frame
        loop=0,
    )


def export_as_zip(images, fname, tab_config=None):
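    """Write the images (and, when given, the tab configuration) into
    outputs/<fname>.zip, numbering the PNGs with zero-padded filenames."""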

    # Ensure the output directory exists before writing the archive.
    os.makedirs("outputs", exist_ok=True)

    with zipfile.ZipFile(f"outputs/{fname}.zip", "w") as img_zip:

        if tab_config:
            # Store the tab settings alongside the images so the run can be
            # reproduced from the archive.
            with open("outputs/config.txt", "w") as f:
                for key, value in tab_config.items():
                    f.write(f"{key}: {value}\n")

            img_zip.write("outputs/config.txt", "config.txt")

        # Zero-pad the filenames so the images sort in generation order.
        pad_width = len(str(len(images)))
        for idx, img in enumerate(images):
            buff = io.BytesIO()
            img[0].save(buff, format="PNG")
            img_zip.writestr(f"{idx + 1:0{pad_width}}.png", buff.getvalue())


def read_html(file_path):
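    """Read an HTML file as UTF-8 and return its contents as a string."""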
    with open(file_path, "r", encoding="utf-8") as f:
        content = f.read()
    return content


__all__ = [
    "get_text_embeddings",
    "generate_latents",
    "generate_modified_latents",
    "generate_images",
    "get_word_embeddings",
    "get_concat_embeddings",
    "get_axis_embeddings",
    "calculate_residual",
    "calculate_step_size",
    "generate_seed_vis",
    "export_as_gif",
    "export_as_zip",
    "read_html",
]