from threading import Thread
from typing import Iterator

from spaces import GPU, config
from transformers import TextIteratorStreamer

from .loader import get_loader


@GPU
def generate(
    message: str,
    chat_history: list[dict[str, str]],
    system_message="",
    model="Qwen/Qwen2.5-0.5B-Instruct",
    max_tokens=512,
    temperature=0.6,
    repetition_penalty=1.2,
    top_p=0.9,
    top_k=50,
) -> Iterator[str]:
    """Stream a chat reply to ``message`` given the prior conversation.

    The conversation sent to the model is: one system entry holding
    ``system_message``, then the prior turns from ``chat_history`` (with any
    leading system entry replaced), then ``message`` as the new user turn.

    ``chat_history`` is never modified; a fresh message list is built on
    every call so a history object reused by the caller does not accumulate
    extra system/user entries.

    Args:
        message: The new user message.
        chat_history: Prior turns as ``{"role": ..., "content": ...}`` dicts.
            ``None`` or an empty list means no prior turns.
        system_message: Content of the system entry placed first.
        model: Model identifier forwarded to ``transformers_generate``.
        max_tokens: Forwarded to ``transformers_generate``.
        temperature: Forwarded to ``transformers_generate``.
        repetition_penalty: Forwarded to ``transformers_generate``.
        top_p: Forwarded to ``transformers_generate``.
        top_k: Forwarded to ``transformers_generate``.

    Yields:
        Whatever ``transformers_generate`` yields for the assembled messages.
    """
    prior_turns = list(chat_history or [])

    # Drop an existing leading system entry; it is replaced below rather than
    # edited in place, so the caller's dict is left untouched.
    if prior_turns and prior_turns[0].get("role") == "system":
        prior_turns = prior_turns[1:]

    messages = [
        {"role": "system", "content": system_message},
        *prior_turns,
        {"role": "user", "content": message},
    ]

    yield from transformers_generate(
        messages,
        model,
        max_tokens,
        temperature,
        repetition_penalty,
        top_p,
        top_k,
    )


def transformers_generate(
    chat_history: list[dict[str, str]],
    model: str,
    max_tokens: int,
    temperature: float,
    repetition_penalty: float,
    top_p: float,
    top_k: int,
) -> Iterator[str]:
    """Stream a sampled completion for ``chat_history`` from a loaded model.

    Loads ``model`` through the loader returned by ``get_loader``, renders
    ``chat_history`` with the tokenizer's chat template, runs
    ``llm.generate`` in a worker thread, and yields the accumulated decoded
    text each time the streamer produces a new chunk.

    Args:
        chat_history: Messages as ``{"role": ..., "content": ...}`` dicts,
            passed to ``tokenizer.apply_chat_template``.
        model: Model identifier passed to ``loader.load``.
        max_tokens: Passed to ``generate`` as ``max_new_tokens``.
        temperature: Sampling temperature passed to ``generate``.
        repetition_penalty: Repetition penalty passed to ``generate``.
        top_p: Nucleus sampling threshold passed to ``generate``.
        top_k: Top-k sampling cutoff passed to ``generate``.

    Yields:
        The full response text generated so far (cumulative, not deltas).

    Raises:
        Exception: Any exception raised by ``llm.generate`` in the worker
            thread is re-raised here once the stream has been drained.
    """
    loader = get_loader(singleton=not config.Config.zero_gpu)
    loader.load(model)

    llm = loader.llm
    tokenizer = loader.tokenizer

    # Handle models that don't have a padding token
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    # https://huggingface.co/docs/transformers/main/en/internal/tokenization_utils#transformers.PreTrainedTokenizerBase.apply_chat_template
    results = tokenizer.apply_chat_template(
        chat_history,
        tokenize=True,
        return_dict=True,  # get the attention mask
        return_tensors="pt",
        # https://huggingface.co/docs/transformers/chat_templating#what-are-generation-prompts
        add_generation_prompt=True,
    )

    input_ids = results["input_ids"].to(llm.device)
    attention_mask = results["attention_mask"].to(llm.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
    )

    # https://huggingface.co/blog/how-to-generate
    generate_kwargs = dict(
        do_sample=True,
        streamer=streamer,
        input_ids=input_ids,
        attention_mask=attention_mask,
        pad_token_id=tokenizer.pad_token_id,
        top_p=top_p,
        top_k=top_k,
        temperature=temperature,
        max_new_tokens=max_tokens,
        repetition_penalty=repetition_penalty,
    )

    errors: list[Exception] = []

    def _run_generation() -> None:
        # If generate() raises, the streamer would never receive its stop
        # signal and the consumer loop below would block forever. Record the
        # error and end the stream so it can be re-raised in the caller.
        try:
            llm.generate(**generate_kwargs)
        except Exception as exc:
            errors.append(exc)
            streamer.end()

    # Stream text off the main thread
    worker = Thread(target=_run_generation)
    worker.start()

    # Yield the cumulative response as each chunk arrives
    response = ""
    for text in streamer:
        response += text
        yield response

    # The stream has ended, so the worker is finished or about to finish.
    worker.join()
    if errors:
        raise errors[0]