import gradio as gr
import os
import requests
import json
import base64
from io import BytesIO
from huggingface_hub import login
from PIL import Image


# Backend server that performs the attack; used to build the /prepare and
# /udiff request URLs. Can be overridden through the environment; the
# defaults are the previously hard-coded address.
myip = os.environ.get("UDIFF_BACKEND_HOST", "146.152.224.103")
myport = int(os.environ.get("UDIFF_BACKEND_PORT", "8080"))

# True when the SPACE_ID environment variable is set (i.e. presumably running
# on a Hugging Face Space). NOTE(review): not referenced in the visible code.
is_spaces = "SPACE_ID" in os.environ

# NOTE(review): not referenced in the visible code; confirm it is still needed.
is_shared_ui = False

from css_html_js import custom_css

from about import (
    CITATION_BUTTON_LABEL,
    CITATION_BUTTON_TEXT,
    EVALUATION_QUEUE_TEXT,
    INTRODUCTION_TEXT,
    LLM_BENCHMARKS_TEXT,
    TITLE,
)


def process_image_from_binary(img_stream):
    """Decode a base64-encoded image payload into a PIL image.

    Args:
        img_stream: Base64 text (or bytes) taken from a backend response, or
            ``None`` when the response carried no image.

    Returns:
        The opened ``PIL.Image.Image``, or ``None`` when *img_stream* is
        ``None``/empty or cannot be decoded as an image.
    """
    if not img_stream:
        # Covers both a missing field (None) and an empty payload ("").
        print("no image binary")
        return None
    try:
        image_data = base64.b64decode(img_stream)
        # Image.open is lazy: only the header is parsed here.
        return Image.open(BytesIO(image_data))
    except (ValueError, OSError) as err:
        # binascii.Error (bad base64) subclasses ValueError; PIL's
        # UnidentifiedImageError (unrecognised format) subclasses OSError.
        print(f"could not decode image: {err}")
        return None

def _call_backend(endpoint, prompt_key, image_key,
                  diffusion_model_id, concept, steps, attack_id):
    """Send one attack request to the backend and unpack its reply.

    Shared by ``execute_prepare`` and ``execute_udiff``, which differ only in
    the endpoint they call and the response fields they read.

    Args:
        endpoint: Backend route name, e.g. ``"prepare"`` or ``"udiff"``.
        prompt_key: Response field holding the prompt text.
        image_key: Response field holding the base64-encoded image.
        diffusion_model_id, concept, steps, attack_id: Forwarded unchanged
            as the JSON request body.

    Returns:
        ``(prompt, img)``. On any failure (network error or timeout, non-200
        status, body that is not a JSON object) returns ``("", None)`` after
        printing the reason. A missing prompt field yields ``""``, and a
        missing image field yields ``None``.
    """
    print(f"my IP is {myip}, my port is {myport}")
    print(f"my input is diffusion_model_id: {diffusion_model_id}, concept: {concept}, "
          f"steps: {steps}, attack_id: {attack_id}")
    url = f"http://{myip}:{myport}/{endpoint}"
    payload = {
        "diffusion_model_id": diffusion_model_id,
        "concept": concept,
        "steps": steps,
        "attack_id": attack_id,
    }
    try:
        # (connect, read) timeouts: fail fast if the server is unreachable,
        # but tolerate up to 1200 s of silence while the backend works.
        response = requests.post(url, json=payload, timeout=(10, 1200))
    except requests.RequestException as err:
        print(f"Request to {url} failed: {err}")
        return "", None
    print(f"result: {response}")
    if response.status_code != 200:
        print(f"Request failed with status code {response.status_code}")
        return "", None
    try:
        response_json = response.json()
    except ValueError as err:
        # requests' JSON decode errors subclass ValueError.
        print(f"Response from {url} is not valid JSON: {err}")
        return "", None
    if not isinstance(response_json, dict):
        print(f"Unexpected response payload from {url}: {type(response_json).__name__}")
        return "", None
    # Log only the field names: the image fields hold base64-encoded images.
    print(f"response fields: {sorted(response_json)}")
    prompt = response_json.get(prompt_key, "")
    img = process_image_from_binary(response_json.get(image_key))
    return prompt, img


def execute_prepare(diffusion_model_id, concept, steps, attack_id):
    """Call the backend ``/prepare`` route.

    Returns:
        ``(input_prompt, no_attack_img)``: the prompt and the image generated
        without the attack, or ``("", None)`` on failure.
    """
    return _call_backend("prepare", "input_prompt", "no_attack_img",
                         diffusion_model_id, concept, steps, attack_id)


def execute_udiff(diffusion_model_id, concept, steps, attack_id):
    """Call the backend ``/udiff`` route (the UnlearnDiffAtk attack).

    Returns:
        ``(output_prompt, attack_img)``: the adversarial prompt and its
        generated image, or ``("", None)`` on failure.
    """
    return _call_backend("udiff", "output_prompt", "attack_img",
                         diffusion_model_id, concept, steps, attack_id)


# Extra layout rules: instruction/arrow positioning, min-height overrides for
# a few auto-numbered components, fixed image heights and text styling.
# NOTE(review): not passed to gr.Blocks in the visible code, which uses
# `custom_css` instead; confirm whether this string is still needed.
css = (
    "\n"
    "    .instruction{position: absolute; top: 0;right: 0;margin-top: 0px !important}\n"
    "    .arrow{position: absolute;top: 0;right: -110px;margin-top: -8px !important}\n"
    "    #component-4, #component-3, #component-10{min-height: 0}\n"
    "    .duplicate-button img{margin: 0}\n"
    "    #img_1, #img_2, #img_3, #img_4{height:15rem}\n"
    "    #mdStyle{font-size: 0.7rem}\n"
    "    #titleCenter {text-align:center}\n"
)


with gr.Blocks(css=custom_css) as demo:
    gr.Markdown  # noqa: B018 (keeps linters quiet about unused gr in partial views)
    gr.HTML(TITLE)
    gr.Markdown(INTRODUCTION_TEXT, elem_classes="markdown-text")

    # Links from an earlier, commented-out inline introduction:
    #   project: https://www.optml-group.com/posts/mu_attack
    #   code:    https://github.com/OPTML-Group/Diffusion-MU-Attack
    #   paper:   https://arxiv.org/abs/2310.11868

    # Attack configuration: target concept, unlearned model, attack index and
    # number of attack steps.
    with gr.Row() as udiff:
        with gr.Row():
            drop = gr.Dropdown(
                ["Object-Church", "Object-Parachute", "Object-Garbage_Truck",
                 "Style-VanGogh", "Nudity"],
                label="Unlearning undesirable concepts",
            )
        with gr.Column():
            drop_model = gr.Dropdown(["ESD", "FMN"], label="Unlearned DMs")
        with gr.Column():
            atk_idx = gr.Textbox(label="attack index")
        with gr.Column():
            shown_columns_step = gr.Slider(
                0, 100, value=40, step=1,
                label="Attack Steps", info="Choose between 0 and 100",
                interactive=True,
            )

    # Side by side: prompt/image without the attack (left) and the
    # UnlearnDiffAtk result (right).
    with gr.Row() as attack:
        with gr.Column(min_width=512):
            start_button = gr.Button("Attack prepare!", size='lg')
            text_input = gr.Textbox(label="Input Prompt")
            orig_img = gr.Image(label="Image Generated by Input Prompt", width=512,
                                show_share_button=False, show_download_button=False)
        with gr.Column():
            attack_button = gr.Button("UnlearnDiffAtk!", size='lg')
            text_ouput = gr.Textbox(label="Prompt Generated by UnlearnDiffAtk")
            result_img = gr.Image(label="Image Generated by Prompt of UnlearnDiffAtk", width=512,
                                  show_share_button=False, show_download_button=False)

        # Both handlers take (diffusion_model_id, concept, steps, attack_id),
        # so the input components must be listed in exactly this order.
        attack_inputs = [drop_model, drop, shown_columns_step, atk_idx]
        start_button.click(fn=execute_prepare, inputs=attack_inputs,
                           outputs=[text_input, orig_img], api_name="prepare")
        attack_button.click(fn=execute_udiff, inputs=attack_inputs,
                            outputs=[text_ouput, result_img], api_name="udiff")


# Enable Gradio's request queue and listen on all network interfaces.
demo.queue().launch(server_name='0.0.0.0')