import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TextIteratorStreamer,
    LlamaTokenizer,
)

# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 1024
# Initial value of the "Max new tokens" slider in the UI.
DEFAULT_MAX_NEW_TOKENS = 50
# Prompts longer than this many tokens are trimmed (oldest tokens dropped)
# inside generate() before being sent to the model.
MAX_INPUT_TOKEN_LENGTH = 512

# Markdown passed to gr.Markdown above the chat widget. The backslash right
# after the opening quotes keeps the string from starting with a newline.
DESCRIPTION = """\
# OpenELM-270M-Instruct -- Running on CPU

This Space demonstrates [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) by Apple. Please, check the original model card for details.

For additional detail on the model, including a link to the arXiv paper, refer to the [Hugging Face Paper page for OpenELM](https://huggingface.co/papers/2404.14619) .

For details on pre-training, instruction tuning, and parameter-efficient finetuning for the model refer to the [OpenELM page in the CoreNet GitHub repository](https://github.com/apple/corenet/tree/main/projects/openelm) .
"""

# Markdown passed to gr.Markdown below the chat widget: a separator followed
# by the license and attribution notice.
LICENSE = """
<p/>

---
As a derivative work of [apple/OpenELM-270M-Instruct](https://huggingface.co/apple/OpenELM-270M-Instruct) by Apple,
this demo is governed by the original [license](https://huggingface.co/apple/OpenELM-270M-Instruct/blob/main/LICENSE)

Based on the [Norod78/OpenELM_3B_Demo](https://huggingface.co/spaces/Norod78/OpenELM_3B_Demo) space.
"""


# Load the causal LM once at import time so every request reuses it.
# The revision is pinned, and trust_remote_code=True allows transformers to
# execute the modeling code shipped in that repository.
model = AutoModelForCausalLM.from_pretrained(
    "apple/OpenELM-270M-Instruct",
    revision="eb111ff",
    trust_remote_code=True,
)
# The tokenizer is loaded from the Llama-2 repository (also at a pinned
# revision) rather than from the model's own repository.
tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    revision="01c7f73",
    trust_remote_code=True,
    tokenizer_class=LlamaTokenizer,
)

# If the tokenizer defines no padding token, reuse EOS for padding and keep
# the model config in agreement with the tokenizer.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id
    model.config.pad_token_id = tokenizer.eos_token_id

def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.1,
    top_p: float = 0.5,
    top_k: int = 3,
    repetition_penalty: float = 1.4,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the prior conversation.

    The prompt is the chat history flattened to plain text (every user and
    assistant turn preceded by a newline) followed by the new message. If the
    tokenized prompt exceeds ``MAX_INPUT_TOKEN_LENGTH`` only the most recent
    tokens are kept and a Gradio warning is shown.

    Generation runs in a background thread; this generator consumes the
    streamer that thread feeds.

    Args:
        message: The new user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Sampling temperature. A value of 0 or below (the UI
            slider allows 0.0) selects greedy decoding, because sampling
            requires a strictly positive temperature.
        top_p: Nucleus-sampling threshold; only used when sampling.
        top_k: Top-k cutoff; only used when sampling.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply text accumulated so far, once per chunk produced by the
        streamer (each yield contains all previous chunks plus the new one).
    """
    # Prepend the entire chat history to the message, with new lines between
    # each turn.
    historical_text = "".join(
        f"\n{user}\n{assistant}" for user, assistant in chat_history
    )
    if historical_text:
        message = f"{historical_text}\n{message}"

    input_ids = tokenizer([message], return_tensors="pt").input_ids
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep the tail of the prompt so the newest turns survive trimming.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # skip_prompt=True: only newly generated text is streamed back.
    # timeout bounds how long the loop below waits for the next chunk.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        num_beams=1,
        pad_token_id=tokenizer.eos_token_id,
        repetition_penalty=repetition_penalty,
        no_repeat_ngram_size=5,
        early_stopping=False,
    )
    if temperature > 0:
        generate_kwargs.update(
            do_sample=True,
            # Coerced to float in case the UI delivers a whole number as int.
            temperature=float(temperature),
            top_p=top_p,
            top_k=top_k,
        )
    else:
        # Sampling with temperature 0 is rejected by transformers, and the
        # error would surface only as a streamer timeout because it happens
        # in the worker thread. Fall back to deterministic greedy decoding.
        generate_kwargs["do_sample"] = False

    # model.generate blocks until completion, so run it in a worker thread
    # and read its output incrementally through the streamer.
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Extra generation controls shown with the chat box, one row per slider:
# (label, minimum, maximum, step, initial value). The rows are listed in the
# same order as generate()'s parameters after message and chat_history.
_slider_specs = [
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.0, 4.0, 0.1, 0.1),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.5),
    ("Top-k", 1, 1000, 1, 3),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.4),
]
_generation_controls = [
    gr.Slider(label=label, minimum=low, maximum=high, step=step, value=initial)
    for label, low, high, step, initial in _slider_specs
]

# Example prompts offered in the UI; each inner list holds one message.
_example_prompts = [
    ["You are three years old.  Count from one to ten."],
    ["Explain quantum physics in 5 words or less:"],
    ["Question: What do you call a bear with no teeth?\nAnswer:"],
]

# Chat UI wired to generate(); example outputs are not cached.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_generation_controls,
    stop_btn="Stop",
    cache_examples=False,
    examples=_example_prompts,
)

# Assemble the page inside a Blocks context, in this order: the description
# markdown, the chat interface built above, then the license markdown.
# NOTE(review): css="style.css" presumably refers to a stylesheet next to
# this script -- confirm the file is shipped with the app.
with gr.Blocks(css="style.css") as demo:
    gr.Markdown(DESCRIPTION)
    chat_interface.render()
    gr.Markdown(LICENSE)

# Launch the app only when executed as a script; the queue is configured
# with max_size=20 before launching.
if __name__ == "__main__":
    demo.queue(max_size=20).launch()