import gradio as gr
import spaces

# from huggingface_hub import InferenceClient
from transformers import AutoModelForCausalLM, AutoTokenizer
import os

# Hugging Face token (e.g. a Space secret) for loading gated or private checkpoints.
HF_TOKEN = os.getenv("HF_TOKEN")

checkpoint = "zidsi/SLlamica_PT4SFT_v2"
device = "cuda"  # "cuda" or "cpu"
tokenizer = AutoTokenizer.from_pretrained(checkpoint, token=HF_TOKEN)
model = AutoModelForCausalLM.from_pretrained(checkpoint, token=HF_TOKEN)
model.to(device)

@spaces.GPU
def predict(message, history, max_new_tokens, temperature, top_p):
    # Append the new user turn and render the whole conversation with the model's chat template.
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    outputs = model.generate(
        inputs,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # generate() returns the prompt plus the completion; keep only the text after the
    # final [/INST] marker, i.e. the newly generated assistant reply.
    # For token-by-token streaming in the UI, see the TextIteratorStreamer sketch below.
    decoded = tokenizer.decode(outputs[0], skip_special_tokens=True)
    response = decoded.split("[INST]")[-1].split("[/INST]")[-1]
    return response
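

# Streaming variant: a sketch only, not wired into the ChatInterface below (which uses
# `predict`). gr.ChatInterface streams the reply when the chat function is a generator,
# and transformers' TextIteratorStreamer yields text chunks while model.generate runs
# in a background thread. The name `predict_stream` is illustrative.
from threading import Thread

from transformers import TextIteratorStreamer


@spaces.GPU
def predict_stream(message, history, max_new_tokens, temperature, top_p):
    history.append({"role": "user", "content": message})
    input_text = tokenizer.apply_chat_template(history, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer.encode(input_text, return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    generation_kwargs = dict(
        inputs=inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=True,
    )
    # Run generation in a background thread and yield the partial reply as it grows.
    Thread(target=model.generate, kwargs=generation_kwargs).start()
    partial = ""
    for new_text in streamer:
        partial += new_text
        yield partial
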

"""
For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
"""
demo = gr.ChatInterface(
    predict,
    type="messages",
    additional_inputs=[
        gr.Slider(minimum=1, maximum=2048, value=256, step=1, label="Max new tokens"),
        gr.Slider(minimum=0.1, maximum=1.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.01,
            label="Top-p (nucleus sampling)",
        ),
    ],
)


if __name__ == "__main__":
    demo.launch()