import datetime
import os
from collections import OrderedDict
from typing import Any

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    LogitsProcessorList,
    TextStreamer,
)

from cache_system import CacheHandler
from download_url import download_text_and_title
from prompts import (
    summarize_clickbait_large_prompt,
    summarize_clickbait_short_prompt,
    summarize_prompt,
)
from utils import StopAfterTokenIsGenerated

auth_token = os.environ.get("TOKEN") or True

total_runs = 0

tokenizer = AutoTokenizer.from_pretrained("Iker/ClickbaitFighter-10B-pro")
model = AutoModelForCausalLM.from_pretrained(
    "Iker/ClickbaitFighter-10B-pro",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    # quantization_config=BitsAndBytesConfig(
    #    load_in_4bit=True,
    #    bnb_4bit_compute_dtype=torch.bfloat16,
    #    bnb_4bit_use_double_quant=True,
    # ),
    # attn_implementation="flash_attention_2",
)

generation_config = GenerationConfig(
    max_new_tokens=256,  # Los resúmenes son cortos, no necesitamos más tokens
    min_new_tokens=1,  # No queremos resúmenes vacíos
    do_sample=True,  # Un poquito mejor que greedy sampling
    num_beams=1,
    use_cache=True,  # Eficiencia
    top_k=40,
    top_p=0.1,
    repetition_penalty=1.1,  # Ayuda a evitar que el modelo entre en bucles
    encoder_repetition_penalty=1.1,  # Favorecemos que el modelo cite el texto original
    temperature=0.15,  #  temperature baja para evitar que el modelo genere texto muy creativo.
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.eos_token_id,
)

stop_words = [
    "<s>",
    "</s>",
    "\\n",
    "[/INST]",
    "[INST]",
    "### User:",
    "### Assistant:",
    "###",
    "<start_of_turn>",
    "<end_of_turn>",
    "<end_of_turn>\\n",
    "<eos>",
    "<|im_end|>",
]


stop_criteria = LogitsProcessorList(
    [
        StopAfterTokenIsGenerated(
            stops=[
                torch.tensor(tokenizer.encode(stop_word, add_special_tokens=False))
                for stop_word in stop_words.copy()
            ],
            eos_token_id=tokenizer.eos_token_id,
        )
    ]
)


class HuggingFaceDatasetSaver_custom(gr.HuggingFaceDatasetSaver):
    def _deserialize_components(
        self,
        data_dir,
        flag_data: list[Any],
        flag_option: str = "",
        username: str = "",
    ) -> tuple[dict[Any, Any], list[Any]]:
        """Deserialize components and return the corresponding row for the flagged sample.

        Images/audio are saved to disk as individual files.
        """

        # Generate the row corresponding to the flagged sample
        features = OrderedDict()
        row = []
        for component, sample in zip(self.components, flag_data):
            label = component.label or ""
            features[label] = {"dtype": "string", "_type": "Value"}
            row.append(sample)

        features["flag"] = {"dtype": "string", "_type": "Value"}
        features["username"] = {"dtype": "string", "_type": "Value"}
        row.append(flag_option)
        row.append(username)
        return features, row


def finish_generation(text: str) -> str:
    return f"{text}\n\n⬇️ Ayuda a mejorar la herramienta marcando si el resumen es correcto o no.⬇️"


@spaces.GPU
def run_model(mode, title, text):
    if mode == 0:
        prompt = summarize_prompt(title, text)
    elif mode == 50:
        prompt = summarize_clickbait_large_prompt(title, text)
    elif mode == 100:
        prompt = summarize_clickbait_short_prompt(title, text)
    else:
        raise ValueError("Mode not supported")

    formatted_prompt = tokenizer.apply_chat_template(
        [{"role": "user", "content": prompt}],
        tokenize=False,
        add_generation_prompt=True,
    )

    model_inputs = tokenizer(
        [formatted_prompt], return_tensors="pt", add_special_tokens=False
    )

    streamer = TextStreamer(
        tokenizer=tokenizer, skip_prompt=True, skip_special_tokens=True
    )

    model_output = model.generate(
        **model_inputs.to(model.device),
        streamer=streamer,
        generation_config=generation_config,
        logits_processor=stop_criteria,
    )

    # yield streamer # Does not work properly on Zero environment

    temp = tokenizer.batch_decode(
        model_output[:, model_inputs["input_ids"].shape[-1] :],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )[0]

    return temp


def generate_text(
    url: str, mode: int, progress=gr.Progress(track_tqdm=False)
) -> (str, str):
    global cache_handler
    global total_runs

    total_runs += 1
    print(f"Total runs: {total_runs}. Last run: {datetime.datetime.now()}")

    url = url.strip()

    if url.startswith("https://twitter.com/") or url.startswith("https://x.com/"):
        yield (
            "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
            "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
            "Error",
        )
        return (
            "🤖 Vaya, parece que has introducido la url de un tweet. No puedo acceder a tweets, tienes que introducir la URL de una noticia.",
            "❌❌❌ Si el tweet contiene una noticia, dame la URL de la noticia ❌❌❌",
            "Error",
        )

    # 1) Download the article

    progress(0, desc="🤖 Accediendo a la noticia")

    # First, check if the URL is in the cache
    title, text, temp = cache_handler.get_from_cache(url, mode)
    if title is not None and text is not None and temp is not None:
        temp = finish_generation(temp)
        yield title, temp, text
        return title, temp, text
    else:
        try:
            title, text, url = download_text_and_title(url)
        except Exception as e:
            print(e)
            title = None
            text = None

        if title is None or text is None:
            yield (
                "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
                "❌❌❌ Inténtalo de nuevo ❌❌❌",
                "Error",
            )
            return (
                "🤖 No he podido acceder a la notica, asegurate que la URL es correcta y que es posible acceder a la noticia desde un navegador.",
                "❌❌❌ Inténtalo de nuevo ❌❌❌",
                "Error",
            )

        # Test if the redirected and clean url is in the cache
        _, _, temp = cache_handler.get_from_cache(url, mode, second_try=True)
        if temp is not None:
            temp = finish_generation(temp)
            yield title, temp, text
            return title, temp, text

        progress(0.5, desc="🤖 Leyendo noticia")

        try:
            temp = run_model(mode, title, text)

        except Exception as e:
            print(e)
            yield (
                "🤖 El servidor no se encuentra disponible.",
                "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
                "Error",
            )
            return (
                "🤖 El servidor no se encuentra disponible.",
                "❌❌❌ Inténtalo de nuevo más tarde ❌❌❌",
                "Error",
            )

        cache_handler.add_to_cache(
            url=url, title=title, text=text, summary_type=mode, summary=temp
        )
        temp = finish_generation(temp)
        yield title, temp, text

    hits, misses, cache_len = cache_handler.get_cache_stats()
    print(
        f"Hits: {hits}, misses: {misses}, cache length: {cache_len}. Percent hits: {round(hits/(hits+misses)*100,2)}%."
    )
    return title, temp, text


cache_handler = CacheHandler(max_cache_size=1000)
hf_writer = HuggingFaceDatasetSaver_custom(
    auth_token, "Iker/Clickbait-News", private=True, separate_dirs=False
)


demo = gr.Interface(
    generate_text,
    inputs=[
        gr.Textbox(
            label="🌐 URL de la noticia",
            info="Introduce la URL de la noticia que deseas resumir.",
            value="https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/",
            interactive=True,
        ),
        gr.Slider(
            minimum=0,
            maximum=100,
            step=50,
            value=50,
            label="🎚️ Nivel de resumen",
            info="""¿Hasta qué punto quieres resumir la noticia? 

Si solo deseas un resumen, selecciona 0.

Si buscas un resumen y desmontar el clickbait, elige 50.

Para obtener solo la respuesta al clickbait, selecciona 100""",
            interactive=True,
        ),
    ],
    outputs=[
        gr.Textbox(
            label="📰 Titular de la noticia",
            interactive=False,
            placeholder="Aquí aparecerá el título de la noticia",
        ),
        gr.Textbox(
            label="🗒️ Resumen",
            interactive=False,
            placeholder="Aquí aparecerá el resumen de la noticia.",
        ),
        gr.Textbox(
            label="Noticia completa",
            visible=False,
            render=False,
            interactive=False,
            placeholder="Aquí aparecerá el resumen de la noticia.",
        ),
    ],
    # title="⚔️ Clickbait Fighter! ⚔️",
    thumbnail="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/logo2.png",
    theme="JohnSmith9982/small_and_pretty",
    description="""
<table>
<tr>   
<td style="width:100%"><img src="https://huggingface.co/spaces/Iker/ClickbaitFighter/resolve/main/head.png" align="right" width="100%"> </td>
</tr>
</table>

<p align="justify">Esta Inteligencia Artificial es capaz de generar un resumen de una sola frase que revela la verdad detrás de un titular sensacionalista o clickbait. Solo tienes que introducir la URL de la noticia. La IA accederá a la noticia, la leerá y en cuestión de segundos generará un resumen de una sola frase que revele la verdad detrás del titular.</p>
   
   🎚 Ajusta el nivel de resumen con el control deslizante. Cuanto maś alto, más corto será el resumen.
   
   ⌚ La IA se encuentra corriendo en un hardware bastante modesto, debería tardar menos de 30 segundos en generar el resumen, pero si muchos usuarios usan la app a la vez, tendrás que esperar tu turno.
   
   💸 Este es un projecto sin ánimo de lucro, no se genera ningún tipo de ingreso con esta app. Los datos, la IA y el código se publicarán para su uso en la investigación académica. No puedes usar esta app para ningún uso comercial.
   
   🧪 El modelo se encuentra en fase de desarrollo, si quieres ayudar a mejorarlo puedes usar los botones 👍 y 👎 para valorar el resumen. ¡Gracias por tu ayuda!""",
    article="Esta Inteligencia Artificial ha sido generada por Iker García-Ferrero. Puedes saber más sobre mi trabajo en mi [página web](https://ikergarcia1996.github.io/Iker-Garcia-Ferrero/) o mi perfil de [X](https://twitter.com/iker_garciaf). Puedes ponerte en contacto conmigo a través de correo electrónico (ver web) y X.",
    cache_examples=False,
    allow_flagging="manual",
    flagging_options=[("👍", "correct"), ("👎", "incorrect")],
    flagging_callback=hf_writer,
    concurrency_limit=20,
)


demo.queue(max_size=None)
demo.launch(share=False)