import os
from threading import Thread
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Upper bound of the "Max new tokens" slider in the UI.
MAX_MAX_NEW_TOKENS = 1024
# Initial value of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 256
# Prompts that tokenize to more than this many tokens are cut down to their
# last MAX_INPUT_TOKEN_LENGTH tokens inside generate().
MAX_INPUT_TOKEN_LENGTH = 512

# Markdown shown at the top of the page (rendered with gr.Markdown below).
# A CPU notice is appended to it further down when CUDA is unavailable.
DESCRIPTION = """\
# OpenELM-3B-Instruct

This Space demonstrates [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) by Apple. Please, check the original model card for details.
You can see the other models of the OpenELM family [here](https://huggingface.co/apple/OpenELM)
The following Colab notebooks are available:
* [OpenELM-3B-Instruct (GPU)](https://gist.github.com/Norod/4f11bb36bea5c548d18f10f9d7ec09b0)
* [OpenELM-270M (CPU)](https://gist.github.com/Norod/5a311a8e0a774b5c35919913545b7af4)

You might also be interested in checking out Apple's [CoreNet Github page](https://github.com/apple/corenet?tab=readme-ov-file).

# Note: Use this model only for completing sentences and instruction following.

If you duplicate this space, make sure that:
1. You have access to [meta-llama/Llama-2-7b-hf](https://huggingface.co/meta-llama/Llama-2-7b-hf) because this model uses it as a tokenizer.
2. You have set a HF_TOKEN secret with a "read" token so this space could access Llama-2-7b-hf on your behalf
"""

# Markdown footer with the license notice (rendered below the chat interface).
LICENSE = """
<p/>

---
As a derivative work of [OpenELM-3B-Instruct](https://huggingface.co/apple/OpenELM-3B-Instruct) by Apple,
this demo is governed by the original [license](https://huggingface.co/apple/OpenELM-3B-Instruct/blob/main/LICENSE).
"""

# The model and tokenizer are only loaded when a CUDA device is available;
# otherwise the page description is extended with a notice and `model` /
# `tokenizer` stay undefined.
if torch.cuda.is_available():
    model_id = "apple/OpenELM-3B-Instruct"
    # trust_remote_code=True allows transformers to run the modeling code
    # shipped inside the model repository.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        trust_remote_code=True,
        low_cpu_mem_usage=True,
    )
    # The tokenizer is taken from a different (gated) repository than the
    # model weights; see the notes in DESCRIPTION.
    tokenizer_id = "meta-llama/Llama-2-7b-hf"
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id)
    # Fall back to the EOS token for padding when the tokenizer defines no
    # pad token, and keep the model config in sync with that choice.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
        tokenizer.pad_token_id = tokenizer.eos_token_id
        model.config.pad_token_id = tokenizer.eos_token_id
else:
    DESCRIPTION += "\n<p>Running on CPU 🥶 This demo does not work on CPU.</p>"

@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.6,
    top_p: float = 0.9,
    top_k: int = 50,
    repetition_penalty: float = 1.4,
) -> Iterator[str]:
    """Stream a sampled completion for ``message`` in the context of the chat.

    The prompt is plain text: every earlier (user, assistant) pair is written
    out with a newline in front of each utterance, followed by a newline and
    the new message. With an empty history the message is used unchanged.

    Args:
        message: The newest user message.
        chat_history: Earlier turns as (user, assistant) pairs.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The text generated so far, growing with every chunk the streamer
        delivers (each yield contains all previous chunks plus the new one).
    """
    # Flatten the earlier turns; each utterance is preceded by a newline.
    past_turns = "".join(f"\n{user}\n{assistant}" for user, assistant in chat_history)
    prompt = f"{past_turns}\n{message}" if past_turns else message

    token_ids = tokenizer([prompt], return_tensors="pt").input_ids
    prompt_length = token_ids.shape[1]
    if prompt_length > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens so the newest message survives.
        token_ids = token_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    token_ids = token_ids.to(model.device)

    # skip_prompt=True: only newly generated text is streamed back.
    text_stream = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    sampling_options = {
        "input_ids": token_ids,
        "streamer": text_stream,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "pad_token_id": tokenizer.eos_token_id,
        "repetition_penalty": repetition_penalty,
        "no_repeat_ngram_size": 5,
        "early_stopping": False,
    }

    # Generation runs in a background thread so this generator can forward
    # text to the UI while the model is still producing tokens.
    worker = Thread(target=model.generate, kwargs=sampling_options)
    worker.start()

    so_far = ""
    for piece in text_stream:
        so_far += piece
        yield so_far


# One row per generation-parameter slider, in the same order as the extra
# parameters of generate() after (message, chat_history):
# (label, minimum, maximum, step, initial value)
_SLIDER_SPECS = (
    ("Max new tokens", 1, MAX_MAX_NEW_TOKENS, 1, DEFAULT_MAX_NEW_TOKENS),
    ("Temperature", 0.1, 4.0, 0.1, 0.6),
    ("Top-p (nucleus sampling)", 0.05, 1.0, 0.05, 0.9),
    ("Top-k", 1, 1000, 1, 50),
    ("Repetition penalty", 1.0, 2.0, 0.05, 1.4),
)

_generation_sliders = [
    gr.Slider(label=label, minimum=lowest, maximum=highest, step=step, value=initial)
    for label, lowest, highest, step, initial in _SLIDER_SPECS
]

# Example prompts offered in the UI; each inner list holds a single message.
_EXAMPLE_PROMPTS = [
    ["A recipe for a chocolate cake:"],
    ["Can you explain briefly to me what is the Python programming language?"],
    ["Once upon a time there was a little"],
    ["Question: What is the capital of France?\nAnswer:"],
    ["Question: I am very tired, what should I do?\nAnswer:"],
]

# Chat UI wired to generate(); the sliders are supplied as additional inputs.
chat_interface = gr.ChatInterface(
    fn=generate,
    additional_inputs=_generation_sliders,
    stop_btn=None,
    examples=_EXAMPLE_PROMPTS,
)

# Page layout, top to bottom: description, duplicate button, chat UI, license.
demo = gr.Blocks(css="style.css")
with demo:
    gr.Markdown(DESCRIPTION)
    gr.DuplicateButton(value="Duplicate Space for private use", elem_id="duplicate-button")
    chat_interface.render()
    gr.Markdown(LICENSE)

if __name__ == "__main__":
    # Enable the request queue (max_size=20) before starting the server.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()