import os
from threading import Thread
from typing import Iterator

import gradio as gr
import torch
from gradio.themes.base import Base
from gradio.themes.utils import colors, fonts, sizes
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Markdown describing the Llama 3.2 1B Instruct base model behind this demo.
DESCRIPTION = """\
# Llama 3.2 1B Instruct
Llama 3.2 1B is Meta's latest iteration of open LLMs.
This is a demo of [`meta-llama/Llama-3.2-1B-Instruct`](https://huggingface.co/meta-llama/Llama-3.2-1B-Instruct), fine-tuned for instruction following.
For more details, please check [the Llama 3.2 release post](https://huggingface.co/blog/llama32).
"""

# Model setup.
# MODEL_ID overrides the Hub repository to load (defaults to the SipanGPT repo).
model_id = os.environ.get("MODEL_ID", "ussipan/SipanGPT-0.3-Llama-3.2-1B-GGUF")

# The default repo name ends in "-GGUF". transformers only loads GGUF weights
# when told which .gguf file to read, so that file name can be supplied through
# GGUF_FILE. The kwarg is only passed when set, which keeps older transformers
# releases (no `gguf_file` support) and non-GGUF repos working unchanged.
# NOTE(review): confirm whether the repo also ships safetensors weights.
_gguf_file = os.environ.get("GGUF_FILE")
_load_kwargs = {"gguf_file": _gguf_file} if _gguf_file else {}

tokenizer = AutoTokenizer.from_pretrained(model_id, **_load_kwargs)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
    torch_dtype=torch.bfloat16,
    **_load_kwargs,
)
model.eval()  # inference only: disable dropout etc.

def generate(
    message: str,
    chat_history: list,
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
) -> Iterator[tuple[str, list]]:
    """Stream the assistant's reply to *message* as Gradio UI updates.

    Appends the user turn to the history, runs ``model.generate`` on a
    background thread, and yields after every decoded chunk so the chatbot
    renders the reply incrementally.

    Args:
        message: Text the user just submitted.
        chat_history: Prior turns in Gradio "messages" format
            (``{"role": ..., "content": ...}`` dicts). The list is copied,
            not mutated; ``None`` is treated as an empty history.
        max_new_tokens: Upper bound on generated tokens. It is coerced to
            ``int`` because Gradio sliders may deliver floats.
        temperature: Sampling temperature (sampling is always enabled).

    Yields:
        ``("", conversation)``: an empty string that clears the input
        textbox, and the conversation including the partial assistant reply.
        A blank *message* yields ``("", history)`` once and generates nothing.
    """
    history = list(chat_history or [])
    if not message or not message.strip():
        # Nothing to answer: clear the textbox and leave the chat untouched.
        yield "", history
        return

    conversation = history + [{"role": "user", "content": message}]

    input_ids = tokenizer.apply_chat_template(
        conversation,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        timeout=20.0,  # raise instead of blocking forever if generation stalls
        skip_prompt=True,
        skip_special_tokens=True,
    )

    generation_kwargs = dict(
        input_ids=input_ids,
        # A single unpadded prompt: every position is a real token, so the
        # mask is all ones. Passing it makes this explicit instead of
        # letting generate() infer it.
        attention_mask=torch.ones_like(input_ids),
        streamer=streamer,
        max_new_tokens=int(max_new_tokens),
        temperature=temperature,
        do_sample=True,
    )

    # generate() blocks until completion, so run it in a worker thread and
    # consume the streamer here to forward text to the UI as it arrives.
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    conversation.append({"role": "assistant", "content": ""})
    output = []
    for text in streamer:
        output.append(text)
        conversation[-1]["content"] = "".join(output)
        yield "", conversation

    # The streamer is exhausted once generation has finished; reap the worker.
    thread.join()

def handle_like(data: gr.LikeData):
    """Print the user's thumbs-up/down rating of a chatbot message.

    Args:
        data: Gradio like-event payload. ``index`` identifies the rated
            message, and ``liked`` is truthy for a positive rating.
    """
    verdict = "bueno" if data.liked else "malo"
    print(f"El mensaje {data.index} fue puntuado como {verdict}.")

class SipanGPTTheme(Base):
    """Green-accented Gradio theme for the SipánGPT demo.

    Extends gradio's ``Base`` theme with two custom green palettes and the
    "Exo 2" Google font. It then overrides a set of light- and dark-mode CSS
    variables via ``self.set``. All constructor arguments are keyword-only
    and are forwarded unchanged to ``Base.__init__``.
    """

    def __init__(
        self,
        *,
        primary_hue: colors.Color | str = colors.Color(
            name="custom_green",
            c50="#f0fde4",
            c100="#e1fbc8",
            c200="#c3f789",
            c300="#a5f34a",
            c400="#7dfa00",  # primary color
            c500="#5ef000",
            c600="#4cc700",
            c700="#39a000",
            c800="#2b7900",
            c900="#1d5200",
            c950="#102e00",
        ),
        secondary_hue: colors.Color | str = colors.Color(
            name="custom_secondary_green",
            c50="#edfce0",
            c100="#dbf9c1",
            c200="#b7f583",
            c300="#93f145",
            c400="#5fed00",  # secondary color
            c500="#4ed400",
            c600="#3fad00",
            c700="#308700",
            c800="#236100",
            c900="#153b00",
            c950="#0a1f00",
        ),
        neutral_hue: colors.Color | str = colors.gray,
        spacing_size: sizes.Size | str = sizes.spacing_md,
        radius_size: sizes.Size | str = sizes.radius_md,
        text_size: sizes.Size | str = sizes.text_md,
        font: fonts.Font | str | list[fonts.Font | str] = [
            fonts.GoogleFont("Exo 2"),
            "ui-sans-serif",
            "system-ui",
            "sans-serif",
        ],
        font_mono: fonts.Font | str | list[fonts.Font | str] = [
            # NOTE(review): Fraunces is a serif display face, not a monospace
            # font, so code spans will not render fixed-width. Kept as-is.
            fonts.GoogleFont("Fraunces"),
            "ui-monospace",
            "monospace",
        ],
    ):
        super().__init__(
            primary_hue=primary_hue,
            secondary_hue=secondary_hue,
            neutral_hue=neutral_hue,
            spacing_size=spacing_size,
            radius_size=radius_size,
            text_size=text_size,
            font=font,
            font_mono=font_mono,
        )
        # "*<hue>_<shade>" strings reference palette shades defined above.
        self.set(
            # Light mode settings
            body_background_fill="*neutral_50",
            body_text_color="*neutral_900",
            color_accent_soft="*secondary_200",
            button_primary_background_fill="*primary_600",
            button_primary_background_fill_hover="*primary_500",
            button_primary_text_color="*neutral_50",
            block_title_text_color="*primary_600",
            input_background_fill="*neutral_200",
            input_border_color="*neutral_300",
            input_placeholder_color="*neutral_500",
            block_background_fill="*neutral_100",
            block_label_background_fill="*primary_100",
            block_label_text_color="*neutral_800",
            checkbox_background_color="*neutral_200",
            checkbox_border_color="*primary_500",
            loader_color="*primary_500",
            slider_color="*primary_500",

            # Dark mode settings
            body_background_fill_dark="*neutral_900",
            body_text_color_dark="*neutral_50",
            color_accent_soft_dark="*secondary_800",
            button_primary_background_fill_dark="*primary_700",
            button_primary_background_fill_hover_dark="*primary_600",
            button_primary_text_color_dark="*neutral_950",
            block_title_text_color_dark="*primary_400",
            input_background_fill_dark="*neutral_800",
            input_border_color_dark="*neutral_700",
            input_placeholder_color_dark="*neutral_400",
            # Palettes only define shades 50, 100-900 and 950. The previous
            # "*neutral_850" pointed at a CSS variable that does not exist,
            # so the dark-mode block background never resolved.
            block_background_fill_dark="*neutral_800",
            block_label_background_fill_dark="*primary_900",
            block_label_text_color_dark="*neutral_200",
            checkbox_background_color_dark="*neutral_800",
            checkbox_border_color_dark="*primary_600",
            loader_color_dark="*primary_400",
            slider_color_dark="*primary_600",
        )

theme = SipanGPTTheme()


def _example_text(evt: gr.SelectData) -> str:
    """Return the text of the chatbot example the user clicked."""
    value = evt.value
    # Chatbot examples are declared below as {"text": ...} dicts.
    return value.get("text", "") if isinstance(value, dict) else str(value)


with gr.Blocks(theme=theme, fill_height=True) as demo:
    chatbot = gr.Chatbot(
        label="SipánGPT 0.3 Llama 3.2",
        # Suggested questions shown inside the empty chat.
        examples=[
            {"text": "Que carreras existen en la uss?"},
            {"text": "Quien es el decano de la facultad de ingenieria?"},
            {"text": "Que maestrias tiene la universidad?"},
        ],
        value=[],
        show_label=True,
        type="messages",
        # PLACEHOLDER is not defined anywhere in this file (NameError at
        # startup), so reuse the DESCRIPTION markdown as the empty-chat
        # placeholder. The Gradio 4-era `bubble_full_width` option was also
        # dropped: it is deprecated and ignored in Gradio 5, which
        # `examples=` above already requires.
        placeholder=DESCRIPTION,
    )

    msg = gr.Textbox(
        show_label=False,
        placeholder="Escribe tu pregunta aquí...",
        scale=4,
    )

    with gr.Row():
        submit = gr.Button("Enviar")
        clear = gr.ClearButton([msg, chatbot])

    with gr.Accordion("Parameters", open=False):
        temperature = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            value=0.6,
            step=0.1,
            label="Temperatura",
        )
        max_new_tokens = gr.Slider(
            minimum=1,
            maximum=2048,
            value=1024,
            step=1,
            label="Máximo de nuevos Tokens",
        )

    # Argument order must match generate(message, chat_history,
    # max_new_tokens, temperature). The outputs clear the textbox and
    # update the chat.
    gen_inputs = [msg, chatbot, max_new_tokens, temperature]
    gen_outputs = [msg, chatbot]

    msg.submit(generate, gen_inputs, gen_outputs)
    submit.click(generate, gen_inputs, gen_outputs)
    # Without a handler, clicking a chatbot example does nothing. Copy its
    # text into the textbox, then ask the model as if it had been typed.
    chatbot.example_select(_example_text, None, msg).then(
        generate, gen_inputs, gen_outputs
    )
    chatbot.like(handle_like)

demo.launch()