import os

import gradio as gr
import openai
import torch


from transformers import (
    BlipForConditionalGeneration,
    BlipForQuestionAnswering,
    BlipProcessor,
)

# Run on GPU when available; models and their inputs must share this device.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# BLIP VQA model: answers a free-form question about an image.
processor = BlipProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
model_vqa = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large").to(device)

# BLIP captioning model: produces a textual description of the image.
cap_processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
cap_model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)



def caption(input_image):
    """Generate a single caption for the input image."""
    inputs = cap_processor(input_image, return_tensors="pt").to(device)
    # Pass generation options to generate() directly instead of stuffing
    # them into the processor's output dict.
    out = cap_model.generate(**inputs, num_beams=1, num_return_sequences=1)
    return "\n".join(cap_processor.batch_decode(out, skip_special_tokens=True))

# The OpenAI key is read from the environment rather than hard-coded.
openai.api_key = os.getenv('openai_appkey')


def gpt3(question, vqa_answer, caption):
    """Ask GPT-3 to pick the right answer given the caption and VQA candidates."""
    prompt = caption + "\n" + question + "\n" + vqa_answer + "\nTell me the right answer."
    response = openai.Completion.create(
        engine="text-davinci-003",
        prompt=prompt,
        max_tokens=10,
        n=1,
        stop=None,
        temperature=0.7,
    )
    return response.choices[0].text.strip()
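
# For reference, the prompt assembled above looks like this
# (hypothetical caption, question, and VQA candidates):
#
#     a man riding a horse on the beach
#     what is the man doing?
#     riding
#     Tell me the right answer.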

    
def inference_chat(input_image, input_text):
    """Return the top VQA answer, a VQA+LLM answer, and a caption+LLM answer."""
    cap = caption(input_image)
    inputs = processor(images=input_image, text=input_text, return_tensors="pt").to(device)
    # Beam search with several return sequences gives GPT-3 a few VQA
    # candidates to choose from.
    out = model_vqa.generate(**inputs, max_length=10, num_beams=5, num_return_sequences=4)
    out = processor.batch_decode(out, skip_special_tokens=True)
    vqa = "\n".join(out)
    gpt3_out = gpt3(input_text, vqa, cap)
    gpt3_out1 = gpt3(input_text, '', cap)
    return out[0], gpt3_out, gpt3_out1
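
# A sketch of a direct call, bypassing the UI (assumes `img` is a PIL.Image):
#
#     vqa_top1, vqa_llm, cap_llm = inference_chat(img, "what color is the car?")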
    
with gr.Blocks(
    css="""
    .message.svelte-w6rprc.svelte-w6rprc.svelte-w6rprc {font-size: 20px; margin-top: 20px}
    #component-21 > div.wrap.svelte-w6rprc {height: 600px;}
    """
) as iface:
    state = gr.State([])
    #caption_output = None
    #gr.Markdown(title)
    #gr.Markdown(description)
    #gr.Markdown(article)

    with gr.Row():
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil")
            with gr.Row():
                with gr.Column(scale=1):
                    chat_input = gr.Textbox(lines=1, label="VQA Input (question)")
                    with gr.Row():
                        clear_button = gr.Button(value="Clear", interactive=True)
                        submit_button = gr.Button(
                            value="VLE", interactive=True, variant="primary"
                        )
                        '''
                    cap_submit_button = gr.Button(
                            value="Submit_CAP", interactive=True, variant="primary"
                        )
                    gpt3_submit_button = gr.Button(
                            value="Submit_GPT3", interactive=True, variant="primary"
                        )
                        '''
        with gr.Column():
            caption_output = gr.Textbox(lines=1, label="VQA")
            gpt3_output_v1 = gr.Textbox(lines=1, label="VQA+LLM")
            caption_output_v1 = gr.Textbox(lines=1, label="CAP+LLM")
            
        # Clear all three answers (and the chat state) whenever the image changes.
        image_input.change(
            lambda: ("", "", "", []),
            [],
            [caption_output, gpt3_output_v1, caption_output_v1, state],
            queue=False,
        )
        # inference_chat returns three values, so wire up all three outputs.
        chat_input.submit(
            inference_chat,
            [
                image_input,
                chat_input,
            ],
            [caption_output, gpt3_output_v1, caption_output_v1],
        )
        clear_button.click(
            lambda: ("", [], "", "", ""),
            [],
            [chat_input, state, caption_output, gpt3_output_v1, caption_output_v1],
            queue=False,
        )
        submit_button.click(
            inference_chat,
            [
                image_input,
                chat_input,
            ],
            [caption_output, gpt3_output_v1, caption_output_v1],
        )
        '''
        cap_submit_button.click(
                        caption,
                        [
                            image_input,
                   
                        ],
                        [caption_output_v1],
                    )
        gpt3_submit_button.click(
                        gpt3,
                        [
                            chat_input,
                           caption_output ,
                            caption_output_v1,
                        ],
                        [gpt3_output_v1],
                    )
        '''

    # examples = gr.Examples(
    #     examples=examples,
    #     inputs=[image_input, chat_input],
    # )

iface.queue(concurrency_count=1, api_open=False, max_size=10)
# queue() above already enables queuing, so launch() needs no extra flag.
iface.launch()