import subprocess
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Log basic CPU information (useful when debugging the host hardware).
result = subprocess.run(["lscpu"], text=True, capture_output=True)
print(result.stdout)


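# Load the fine-tuned Gemma-2b checkpoint and its tokenizer from the Hugging Face Hub.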
tokenizer = AutoTokenizer.from_pretrained("suriya7/Gemma-2b-SFT")
model = AutoModelForCausalLM.from_pretrained("suriya7/Gemma-2b-SFT")


def run_generation(user_text, temperature, top_k, max_new_tokens):
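    """Stream the model's reply to `user_text`, yielding the partial output as it grows."""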
    # Alpaca-style prompt template, dedented so no stray indentation reaches the model.
    alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

### Instruction:
{}

### Input:
{}

### Response:
{}"""

    inputs = tokenizer(
        [
            alpaca_prompt.format(
                "You are an AI assistant. Please ensure that the answers conclude with an end-of-sequence (EOS) token.",  # instruction
                user_text,  # input goes here
                "",  # output - leave this blank for generation!
            )
        ],
        return_tensors="pt",
    )

    # Run generation on a separate thread so the UI is not blocked; the text is pulled
    # from the streamer in the main thread. The streamer's timeout surfaces exceptions
    # raised in the generation thread instead of hanging forever.
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
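    # `inputs` is a dict-like BatchEncoding, so its tensors (input_ids, attention_mask)
    # can be merged with the sampling kwargs via dict().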
    generate_kwargs = dict(
        inputs,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        repetition_penalty=1.5,
        temperature=temperature,
        top_k=top_k,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    # Pull the generated text from the streamer, yielding the growing output to the UI.
    model_output = ""
    for new_text in streamer:
        model_output += new_text
        yield model_output

def reset_textbox():
    # Helper for clearing the input textbox (not wired to a button in this demo).
    return gr.update(value="")


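# Build the UI: user input and model output on the left, sampling controls on the right.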
with gr.Blocks() as demo:

    with gr.Row():
        with gr.Column(scale=4):
            user_text = gr.Textbox(
                label="User input",
            )
            model_output = gr.Textbox(label="Model output", lines=10, interactive=False)
            button_submit = gr.Button(value="Submit")

        with gr.Column(scale=1):
            max_new_tokens = gr.Slider(
                minimum=1,
                maximum=1000,
                value=250,
                step=1,
                interactive=True,
                label="Max New Tokens",
            )
        
            top_k = gr.Slider(
                minimum=1,
                maximum=50,
                value=50,
                step=1,
                interactive=True,
                label="Top-k",
            )
            temperature = gr.Slider(
                minimum=0.1,
                maximum=5.0,
                value=0.8,
                step=0.1,
                interactive=True,
                label="Temperature",
            )

    user_text.submit(
        run_generation,
        [user_text, temperature, top_k, max_new_tokens],
        model_output,
    )
    button_submit.click(
        run_generation,
        [user_text, temperature, top_k, max_new_tokens],
        model_output,
    )

demo.queue(max_size=32).launch(server_name="0.0.0.0")
# For local use:
# demo.launch()