import os
from threading import Thread

import torch
import gradio as gr
import spaces
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.utils.import_utils import _is_package_available

# Read the Hugging Face token from the environment, if provided
HF_TOKEN = os.environ.get("HF_TOKEN", None)


DESCRIPTION = """
# [MInference 1.0: Accelerating Pre-filling for Long-Context LLMs via Dynamic Sparse Attention](https://aka.ms/MInference) (NeurIPS'24 Spotlight)

_Huiqiang Jiang†, Yucheng Li†, Chengruidong Zhang†, Qianhui Wu, Xufang Luo, Surin Ahn, Zhenhua Han, Amir H. Abdi, Dongsheng Li, Chin-Yew Lin, Yuqing Yang and Lili Qiu_

<h3 style="text-align: center;"><a href="https://github.com/microsoft/MInference" target="blank"> [Code]</a> 
<a href="https://aka.ms/MInference" target="blank"> [Project Page]</a>
<a href="https://arxiv.org/abs/2407.02490" target="blank"> [Paper]</a></h3>

## News
- 🧤 [24/09/26] MInference has been accepted as a **spotlight** paper at **NeurIPS'24**. See you in Vancouver!
- 👘 [24/09/16] We are pleased to announce the release of our KV cache offloading work, [RetrievalAttention](https://aka.ms/RetrievalAttention), which accelerates long-context LLM inference via vector retrieval.
- 🥤 [24/07/24] MInference now supports [meta-llama/Meta-Llama-3.1-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-8B-Instruct).
- 🪗 [24/07/07] Thanks @AK for sponsoring. You can now use MInference online in the [HF Demo](https://huggingface.co/spaces/microsoft/MInference) with ZeroGPU.
- 📃 [24/07/03] Due to an issue with arXiv, the PDF is currently unavailable there. You can find the paper at this [link](https://export.arxiv.org/pdf/2407.02490).
- 🧩 [24/07/03] We will present **MInference 1.0** at the _**Microsoft Booth**_ and _**ES-FoMo**_ at ICML'24. See you in Vienna!

<font color="brown"><b>This Space is a deployment demo only. You can follow the steps below to try MInference locally.</b></font>

```bash
git clone https://huggingface.co/spaces/microsoft/MInference
cd MInference
pip install -r requirements.txt
pip install flash_attn pycuda==2023.1
python app.py
```
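
Once installed, MInference is applied to a Hugging Face model with a single patch call. The snippet below is a minimal sketch that mirrors the code in this demo (same model name and `MInference("minference", model_name)` patch); the prompt is only a placeholder, and it assumes a CUDA GPU with `pycuda` and `flash_attn` available.

```python
from transformers import AutoModelForCausalLM, AutoTokenizer
from minference import MInference

model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# Patch the model with MInference's dynamic sparse attention to speed up pre-filling.
minference_patch = MInference("minference", model_name)
model = minference_patch(model)

# Generate as usual with the patched model.
prompt = "Summarize the following document: ..."
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(**inputs, max_new_tokens=256)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```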
"""

LICENSE = """
<div style="text-align: center;">
    <p>© 2024 Microsoft</p>
</div>
"""

PLACEHOLDER = """
<div style="padding: 30px; text-align: center; display: flex; flex-direction: column; align-items: center;">
   <h1 style="font-size: 28px; margin-bottom: 2px; opacity: 0.55;">LLaMA-3-8B-Gradient-1M w/ MInference</h1>
   <p style="font-size: 18px; margin-bottom: 2px; opacity: 0.65;">Ask me anything...</p>
</div>
"""


css = """
h1 {
  text-align: center;
  display: block;
}
"""

# Load the tokenizer and model (fall back to a small model when no GPU is available)
model_name = "gradientai/Llama-3-8B-Instruct-Gradient-1048k" if torch.cuda.is_available() else "Qwen/Qwen2-0.5B"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype="auto", device_map="auto"
)

# Patch the model with MInference's dynamic sparse attention when a GPU and pycuda are available
if torch.cuda.is_available() and _is_package_available("pycuda"):
    from minference import MInference

    minference_patch = MInference("minference", model_name)
    model = minference_patch(model)

# Stop generation at either the model's EOS token or Llama-3's end-of-turn token
terminators = [tokenizer.eos_token_id, tokenizer.convert_tokens_to_ids("<|eot_id|>")]


@spaces.GPU(duration=120)
def chat_llama3_8b(
    message: str, history: list, temperature: float, max_new_tokens: int
) -> str:
    """
    Generate a streaming response using the llama3-8b model.
    Args:
        message (str): The input message.
        history (list): The conversation history used by ChatInterface.
        temperature (float): The temperature for generating the response.
        max_new_tokens (int): The maximum number of new tokens to generate.
    Yields:
        str: The response generated so far, streamed as it is produced.
    """
    conversation = []
    for user, assistant in history:
        conversation.extend(
            [
                {"role": "user", "content": user},
                {"role": "assistant", "content": assistant},
            ]
        )
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(
        conversation, add_generation_prompt=True, return_tensors="pt"
    ).to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )

    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        eos_token_id=terminators,
    )
    # Force greedy decoding (do_sample=False) when the temperature is 0 to avoid a crash.
    if temperature == 0:
        generate_kwargs["do_sample"] = False

    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        yield "".join(outputs)


# Gradio block
chatbot = gr.Chatbot(height=450, placeholder=PLACEHOLDER, label="Gradio ChatInterface")

with gr.Blocks(fill_height=True, css=css) as demo:

    gr.Markdown(DESCRIPTION)
    gr.ChatInterface(
        fn=chat_llama3_8b,
        chatbot=chatbot,
        fill_height=True,
        additional_inputs_accordion=gr.Accordion(
            label="⚙️ Parameters", open=False, render=False
        ),
        additional_inputs=[
            gr.Slider(
                minimum=0,
                maximum=1,
                step=0.1,
                value=0.95,
                label="Temperature",
                render=False,
            ),
            gr.Slider(
                minimum=128,
                maximum=4096,
                step=1,
                value=512,
                label="Max new tokens",
                render=False,
            ),
        ],
        examples=[
            ["How to setup a human base on Mars? Give short answer."],
            ["Explain theory of relativity to me like I’m 8 years old."],
            ["What is 9,000 * 9,000?"],
            ["Write a pun-filled happy birthday message to my friend Alex."],
            ["Justify why a penguin might make a good king of the jungle."],
        ],
        cache_examples=False,
    )

    gr.Markdown(LICENSE)

if __name__ == "__main__":
    demo.launch(share=False)