import gradio as gr
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForVision2Seq
import spaces
# Optional: flash-attn speeds up attention on supported GPUs. Uncomment to install at startup:
# import subprocess
# subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)


# Load the SmolVLM processor and model; bfloat16 halves memory use versus float32.
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
model = AutoModelForVision2Seq.from_pretrained(
    "HuggingFaceTB/SmolVLM-Instruct",
    torch_dtype=torch.bfloat16,
    # _attn_implementation="flash_attention_2",  # enable if flash-attn is installed
).to("cuda")
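# Note: the app assumes a CUDA device throughout (the model and the inputs are both
# moved to "cuda"). A device-agnostic fallback, as an untested sketch:
#   device = "cuda" if torch.cuda.is_available() else "cpu"
#   model = model.to(device)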

@spaces.GPU  # on Hugging Face Spaces (ZeroGPU), allocates a GPU for the duration of each call
def model_inference(
    images, text, assistant_prefix, decoding_strategy, temperature, max_new_tokens,
    repetition_penalty, top_p
):
    if text == "" and not images:
        gr.Error("Please input a query and optionally image(s).")

    if text == "" and images:
        gr.Error("Please input a text query along the image(s).")

    if isinstance(images, Image.Image):
        images = [images]


    resulting_messages = [
                {
                    "role": "user",
                    "content": [{"type": "image"}] + [
                        {"type": "text", "text": text}
                    ]
                }
            ]

    if assistant_prefix:
      text = f"{assistant_prefix} {text}"


    # Render the chat template (adds the generation prompt), then tokenize text and image(s)
    prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[images] if images else None, return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    generation_args = {
        "max_new_tokens": max_new_tokens,
        "repetition_penalty": repetition_penalty,
    }

    assert decoding_strategy in [
        "Greedy",
        "Top P Sampling",
    ]
    if decoding_strategy == "Greedy":
        generation_args["do_sample"] = False
    elif decoding_strategy == "Top P Sampling":
        # Nucleus sampling: sample from the smallest set of tokens whose cumulative
        # probability exceeds top_p, after temperature scaling
        generation_args["temperature"] = temperature
        generation_args["do_sample"] = True
        generation_args["top_p"] = top_p

    # Merge the tokenized inputs into the generate() kwargs
    generation_args.update(inputs)

    # Generate, then decode only the newly generated tokens (everything after the prompt)
    generated_ids = model.generate(**generation_args)
    generated_texts = processor.batch_decode(
        generated_ids[:, generation_args["input_ids"].size(1):],
        skip_special_tokens=True,
    )
    return generated_texts[0]
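
# Minimal local sanity check, as a sketch (assumes a CUDA GPU and that
# example_images/rococo.jpg from the bundled examples exists):
#   img = Image.open("example_images/rococo.jpg")
#   print(model_inference(img, "What art era is this?", "", "Greedy", 0.4, 64, 1.2, 0.8))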


with gr.Blocks(fill_height=False) as demo:
    gr.Markdown("## SmolVLM: Small yet Mighty 💫")
    gr.Markdown("Play with [HuggingFaceTB/SmolVLM-Instruct](https://huggingface.co/HuggingFaceTB/SmolVLM-Instruct) in this demo. To get started, upload an image and text or try one of the examples.")
    with gr.Column():
        with gr.Row():
            image_input = gr.Image(label="Upload your Image", type="pil")

            with gr.Column():
                query_input = gr.Textbox(label="Prompt")
                assistant_prefix = gr.Textbox(label="Assistant Prefix", placeholder="Let's think step by step.")

                submit_btn = gr.Button("Submit")
        output = gr.Textbox(label="Output")

        
        with gr.Accordion(label="Advanced Generation Parameters", open=False):
            # Each example row matches the gr.Examples inputs below: image, prompt,
            # assistant prefix, decoding strategy, temperature, max new tokens,
            # repetition penalty, top-p
            examples = [
                ["example_images/rococo.jpg", "What art era is this?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
                ["example_images/examples_wat_arun.jpg", "I'm planning a visit to this temple, give me travel tips.", "", "Greedy", 0.4, 512, 1.2, 0.8],
                ["example_images/examples_invoice.png", "What is the due date and the invoice date?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
                ["example_images/s2w_example.png", "What is this UI about?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
                ["example_images/examples_weather_events.png", "Where do the severe droughts happen according to this diagram?", "", "Top P Sampling", 0.4, 512, 1.2, 0.8],
            ]
            # Hyper-parameters for generation
            max_new_tokens = gr.Slider(
                minimum=8,
                maximum=1024,
                value=512,
                step=1,
                interactive=True,
                label="Maximum number of new tokens to generate",
            )
            repetition_penalty = gr.Slider(
                minimum=0.01,
                maximum=5.0,
                value=1.2,
                step=0.01,
                interactive=True,
                label="Repetition penalty",
                info="1.0 is equivalent to no penalty",
            )
            temperature = gr.Slider(
                minimum=0.0,
                maximum=5.0,
                value=0.4,
                step=0.1,
                interactive=True,
                label="Sampling temperature",
                info="Higher values will produce more diverse outputs.",
            )
            top_p = gr.Slider(
                minimum=0.01,
                maximum=0.99,
                value=0.8,
                step=0.01,
                interactive=True,
                label="Top P",
                info="Higher values is equivalent to sampling more low-probability tokens.",
            )
            decoding_strategy = gr.Radio(
                [
                    "Top P Sampling",
                    "Greedy",
                ],
                value="Top P Sampling",
                label="Decoding strategy",
                interactive=True,
                info="Greedy always picks the most likely next token; Top P samples from the most likely tokens.",
            )
            # Temperature and top-p only affect sampling, so hide them when greedy
            # decoding is selected. Repetition penalty applies to both strategies
            # and stays visible.
            decoding_strategy.change(
                fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
                inputs=decoding_strategy,
                outputs=temperature,
            )
            decoding_strategy.change(
                fn=lambda selection: gr.Slider(visible=(selection == "Top P Sampling")),
                inputs=decoding_strategy,
                outputs=top_p,
            )
        gr.Examples(
            examples=examples,
            inputs=[image_input, query_input, assistant_prefix, decoding_strategy,
                    temperature, max_new_tokens, repetition_penalty, top_p],
            outputs=output,
            fn=model_inference,
        )
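        # To precompute example outputs at startup, gr.Examples also accepts
        # cache_examples=True; it is left off here, which keeps startup fast.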
        

        submit_btn.click(
            model_inference,
            inputs=[image_input, query_input, assistant_prefix, decoding_strategy,
                    temperature, max_new_tokens, repetition_penalty, top_p],
            outputs=output,
        )


demo.launch(debug=True)