import string
from io import BytesIO

import gradio as gr
import requests

from utils import Endpoint


def encode_image(image):
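    """Serialize a PIL image into an in-memory JPEG buffer ready for upload."""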
    # JPEG has no alpha channel, so normalize to RGB first (e.g. for RGBA PNGs).
    image = image.convert("RGB")
    buffered = BytesIO()
    image.save(buffered, format="JPEG")
    buffered.seek(0)

    return buffered


def query_chat_api(
    image, prompt, decoding_method, temperature, len_penalty, repetition_penalty
):
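    """POST the prompt and image to the chat endpoint; return the decoded JSON on success."""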

    url = endpoint.url

    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "prompt": prompt,
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

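    # Send the image as a multipart file upload alongside the form fields.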
    image = encode_image(image)
    files = {"image": image}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        # Wrap the error in a list so callers can index it like a normal result.
        return ["Error: " + response.text]


def query_caption_api(
    image, decoding_method, temperature, len_penalty, repetition_penalty
):
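    """POST the image to the caption endpoint; return the decoded JSON on success."""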

    url = endpoint.url
    # The caption route lives on the same host, at /caption instead of /generate.
    url = url.replace("/generate", "/caption")

    headers = {"User-Agent": "BLIP-2 HuggingFace Space"}

    data = {
        "use_nucleus_sampling": decoding_method == "Nucleus sampling",
        "temperature": temperature,
        "length_penalty": len_penalty,
        "repetition_penalty": repetition_penalty,
    }

    image = encode_image(image)
    files = {"image": image}

    response = requests.post(url, data=data, files=files, headers=headers)

    if response.status_code == 200:
        return response.json()
    else:
        # Wrap the error in a list so callers can index it like a normal result.
        return ["Error: " + response.text]


def postprocess_output(output):
    # If the generated text does not end with punctuation, append a full stop.
    if output and output[0] and output[0][-1] not in string.punctuation:
        output[0] += "."

    return output


def inference_chat(
    image,
    text_input,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
    history=None,
):
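    """Append the user turn to history, query the chat API, and refresh the chat log."""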
    # Default to a fresh list per call to avoid the shared mutable-default pitfall.
    if history is None:
        history = []
    history.append(text_input)

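    # The prompt is the whole dialogue so far, joined into a single string.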
    prompt = " ".join(history)
    print(prompt)

    output = query_chat_api(
        image, prompt, decoding_method, temperature, length_penalty, repetition_penalty
    )
    output = postprocess_output(output)
    history += output

    chat = [
        (history[i], history[i + 1]) for i in range(0, len(history) - 1, 2)
    ]  # pair up alternating (user, bot) turns for the Chatbot component

    return {chatbot: chat, state: history}


def inference_caption(
    image,
    decoding_method,
    temperature,
    length_penalty,
    repetition_penalty,
):
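    """Query the caption API and return the first caption string."""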
    output = query_caption_api(
        image, decoding_method, temperature, length_penalty, repetition_penalty
    )

    return output[0]


title = """<h1 align="center">BLIP-2</h1>"""
description = """<p>Gradio demo for BLIP-2, a multimodal chatbot from Salesforce Research. To use it, simply upload your image, or click one of the examples to load it. Please visit our <a href='https://github.com/salesforce/LAVIS/tree/main/projects/blip2' target='_blank'>project webpage</a>.</p>
<p><strong>Disclaimer</strong>: This is a research prototype and is not intended for production use. No data, including but not limited to text and images, is collected.</p>"""
article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2301.12597' target='_blank'>BLIP-2: Bootstrapping Language-Image Pre-training with Frozen Image Encoders and Large Language Models</a></p>"

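# Endpoint is defined in utils.py (not shown here); the query helpers only use its .url.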
endpoint = Endpoint()

examples = [
    # image, chat prompt, decoding method, temperature, length penalty, repetition penalty
    ["house.png", "How could someone get out of the house?", "Beam search", 0.8, 1.0, 1.5],
    # [
    #     "sunset.png",
    #     "Write a romantic message that goes along this photo.",
    # ],
]

with gr.Blocks() as iface:
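    # Per-session chat history: a flat list of alternating user/bot messages.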
    state = gr.State([])

    gr.Markdown(title)
    gr.Markdown(description)
    gr.Markdown(article)
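    # Left column: image upload, decoding controls, and captioning; right column: chat.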
    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="pil")

            with gr.Row():
                sampling = gr.Radio(
                    choices=["Beam search", "Nucleus sampling"],
                    value="Beam search",
                    label="Text Decoding Method",
                    interactive=True,
                )

                temperature = gr.Slider(
                    minimum=0.5,
                    maximum=1.0,
                    value=0.8,
                    interactive=True,
                    label="Temperature (set to 0 for greedy decoding with nucleus sampling)",
                )

                len_penalty = gr.Slider(
                    minimum=-2.0,
                    maximum=2.0,
                    value=1.0,
                    step=0.5,
                    interactive=True,
                    label="Length Penalty (larger value encourages longer sequence with beam search)",
                )

                rep_penalty = gr.Slider(
                    minimum=1.0,
                    maximum=5.0,
                    value=1.5,
                    step=0.5,
                    interactive=True,
                    label="Repeat Penalty (larger value prevents repetition)",
                )

            with gr.Row():
                caption_output = gr.Textbox(lines=2, label="Caption Output")
                caption_button = gr.Button(
                    value="Caption it!", interactive=True, variant="primary"
                )
                caption_button.click(
                    inference_caption,
                    [
                        image_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                    ],
                    [caption_output],
                )

        with gr.Column():
            chat_input = gr.Textbox(lines=2, label="Chat Input")

            with gr.Row():
                chatbot = gr.Chatbot()
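                # Uploading a new image resets the chat, caption box, and history.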
                image_input.change(
                    lambda: (None, "", "", []),
                    [],
                    [chatbot, chat_input, caption_output, state],
                )

            with gr.Row():

                clear_button = gr.Button(value="Clear", interactive=True)
                clear_button.click(
                    lambda: ("", None, [], []),
                    [],
                    [chat_input, image_input, chatbot, state],
                )

                submit_button = gr.Button(
                    value="Submit", interactive=True, variant="primary"
                )
                submit_button.click(
                    inference_chat,
                    [
                        image_input,
                        chat_input,
                        sampling,
                        temperature,
                        len_penalty,
                        rep_penalty,
                        state,
                    ],
                    [chatbot, state],
                )

    # state is omitted from the inputs: example rows cannot seed a gr.State, so
    # each example click starts a fresh conversation via the history=None default.
    gr.Examples(
        examples=examples,
        inputs=[image_input, chat_input, sampling, temperature, len_penalty, rep_penalty],
        outputs=[chatbot, state],
        fn=inference_chat,
        run_on_click=True,
    )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
iface.launch()