#!/usr/bin/env python

from __future__ import annotations

import os
import random
from typing import Tuple, Optional

import gradio as gr
from huggingface_hub import HfApi

from inf import InferencePipeline


# Curated B-LoRA checkpoints on the Hub. They are listed first in the model
# dropdowns and form the pool from which random content/style defaults are drawn.
# NOTE: keep a comma after every entry -- adjacent string literals are
# implicitly concatenated by Python, which would silently fuse two IDs.
SAMPLE_MODEL_IDS = [
    'lora-library/B-LoRA-teddybear',
    'lora-library/B-LoRA-bull',
    'lora-library/B-LoRA-wolf_plushie',
    'lora-library/B-LoRA-pen_sketch',
    'lora-library/B-LoRA-cartoon_line',
    'lora-library/B-LoRA-child',
    'lora-library/B-LoRA-vase',
    'lora-library/B-LoRA-scary_mug',
    'lora-library/B-LoRA-statue',
    'lora-library/B-LoRA-colorful_teapot',
    'lora-library/B-LoRA-grey_sloth_plushie',
    'lora-library/B-LoRA-teapot',
    'lora-library/B-LoRA-backpack_dog',
    'lora-library/B-LoRA-buddha',
    'lora-library/B-LoRA-dog6',
    'lora-library/B-LoRA-poop_emoji',
    'lora-library/B-LoRA-pot',
    'lora-library/B-LoRA-fat_bird',
    'lora-library/B-LoRA-elephant',
    'lora-library/B-LoRA-metal_bird',
    'lora-library/B-LoRA-cat',
    'lora-library/B-LoRA-dog2',
    'lora-library/B-LoRA-drawing1',
    'lora-library/B-LoRA-village_oil',
    'lora-library/B-LoRA-watercolor',
    'lora-library/B-LoRA-house_3d',
    'lora-library/B-LoRA-ink_sketch',
    'lora-library/B-LoRA-drawing3',
    'lora-library/B-LoRA-crayon_drawing',
    'lora-library/B-LoRA-kiss',
    'lora-library/B-LoRA-drawing4',
    'lora-library/B-LoRA-working_cartoon',
    'lora-library/B-LoRA-painting',
    'lora-library/B-LoRA-drawing2',
    'lora-library/B-LoRA-multi-dog2',
]
# Custom stylesheet passed to ``gr.Blocks(css=...)``. The ``.lora-title``,
# ``.lora-column``, ``.gr-row``, ``.gr-image`` and ``.res-image`` rules target
# components through their ``elem_classes``; ``#title`` targets the header
# HTML through ``elem_id="title"``.
# NOTE(review): ``.svelte-iyf88w`` looks like a Svelte-generated scoped class
# from one specific Gradio build; it will likely stop matching after a Gradio
# upgrade -- confirm and prefer an elem_classes hook if it matters.
css = """
.gradio-container {
    max-width: 900px !important;
}

#title {
    text-align: center;
}

#title h1 {
    font-size: 250%;
}
    
.lora-title {
            background-image: linear-gradient(to right, #314755 0%, #26a0da  51%, #314755  100%);
            text-align: center;        
            border-radius: 10px;
            display: block;
          }

.lora-title h2 {
    color: white !important;
}

.gr-image {
    width: 256px;
    height: 256px;
    object-fit: contain;
    margin: auto;
}

.res-image {
    object-fit: contain;
    margin: auto;
}

.lora-column {
    border: none;
    background: none;
}
.gr-row {
    align-items: center;
    justify-content: center;
    margin-top: 5px;
}

.svelte-iyf88w {
    background: none;
}
"""

def get_choices(hf_token):
    """Build the list of model IDs offered in the B-LoRA dropdowns.

    The list starts with the sentinel ``'None'`` (meaning "no B-LoRA"),
    followed by the curated ``SAMPLE_MODEL_IDS`` and then every model
    published under the ``lora-library`` account on the Hub.

    Args:
        hf_token: Hugging Face access token passed to ``HfApi``; may be None.

    Returns:
        A list of unique model-ID strings, in first-seen order.
    """
    api = HfApi(token=hf_token)
    # ``id`` is the canonical repo-ID attribute of ``ModelInfo``.
    hub_ids = [info.id for info in api.list_models(author='lora-library')]
    # The curated samples also live under ``lora-library``, so the Hub listing
    # repeats them; ``dict.fromkeys`` drops duplicates while keeping order
    # (sentinel first, then samples, then the remaining Hub models).
    return list(dict.fromkeys(['None', *SAMPLE_MODEL_IDS, *hub_ids]))


def get_image_from_card(card, model_id) -> Optional[str]:
    """Return the preview-image URL declared in a model card's widget, if any.

    Reads ``output.url`` from the first entry of the card's ``widget``
    metadata. A relative URL is resolved against the repo's
    ``https://huggingface.co/<model_id>/resolve/main/`` endpoint; an absolute
    http(s) URL is returned unchanged.

    Args:
        card: Model card object whose ``data`` attribute supports ``.get``.
        model_id: Repo ID (``owner/name``) used to build the resolve URL.

    Returns:
        The image URL, or None when the card has no usable widget output.
    """
    card_path = f"https://huggingface.co/{model_id}/resolve/main/"
    try:
        widget = card.data.get('widget')
        # ``not widget`` covers both a missing (None) and an empty widget list.
        if not widget:
            return None
        output = widget[0].get('output')
        url = output.get('url') if output else None
    except (AttributeError, TypeError, KeyError, IndexError):
        # Malformed or unexpected card metadata: treat as "no preview image".
        return None
    if not isinstance(url, str) or not url:
        return None
    if url.startswith(('http://', 'https://')):
        return url
    return card_path + url


def demo_init():
    """Populate the UI with a random content/style B-LoRA pair on page load.

    Returns:
        A 7-tuple of ``gr.update`` objects in the order expected by
        ``demo.load``: content model dropdown, content prompt, content image,
        style model dropdown, style prompt, style image, and the combined
        generation prompt.

    Raises:
        Any failure is re-raised as ``type(e)`` with a ``failed to demo_init``
        prefix, chained to the original exception.
    """
    try:
        choices = get_choices(app.hf_token)
        content_blora = random.choice(SAMPLE_MODEL_IDS)
        style_blora = random.choice(SAMPLE_MODEL_IDS)
        content_blora_prompt, content_blora_image = app.load_model_info(content_blora)
        style_blora_prompt, style_blora_image = app.load_model_info(style_blora)

        # Reuse the same composition rule as the prompt-change handlers. It
        # also copes with an empty prompt (load_model_info returns '' when a
        # card cannot be fetched), where indexing ``prompt[0]`` would raise.
        combined_prompt = handle_prompt_change(content_blora_prompt, style_blora_prompt)

        return (
            gr.update(choices=choices, value=content_blora),
            gr.update(value=content_blora_prompt),
            gr.update(value=content_blora_image),
            gr.update(choices=choices, value=style_blora),
            gr.update(value=style_blora_prompt),
            gr.update(value=style_blora_image),
            gr.update(value=combined_prompt),
        )
    except Exception as e:
        raise type(e)(f'failed to demo_init, due to: {e}') from e


def toggle_column(is_checked):
    """Choose the value for the opposite B-LoRA dropdown when a checkbox flips.

    Checking a "Use ... Only" box disables the other B-LoRA by selecting the
    ``'None'`` sentinel; unchecking it restores a randomly picked sample model.
    """
    try:
        if not is_checked:
            return random.choice(SAMPLE_MODEL_IDS)
        return 'None'
    except Exception as err:
        raise type(err)(f'failed to toggle_column, due to: {err}')

def handle_prompt_change(content_blora_prompt, style_blora_prompt) -> str:
    """Compose the generation prompt from the two instance prompts.

    - content and style: ``"<content> in <style> style"``
    - content only: the content prompt unchanged
    - style only: ``"A dog in <style> style"``
    - neither: ``''``

    The style prompt's first character is lower-cased so it reads naturally
    mid-sentence.
    """
    try:
        if not style_blora_prompt:
            return content_blora_prompt if content_blora_prompt else ''
        style_phrase = style_blora_prompt[0].lower() + style_blora_prompt[1:]
        subject = content_blora_prompt if content_blora_prompt else 'A dog'
        return f'{subject} in {style_phrase} style'
    except Exception as err:
        raise type(err)(f'failed to handle_prompt_change, due to: {err}')


class InferenceUtil:
    def __init__(self, hf_token: str | None):
        self.hf_token = hf_token

    def load_model_info(self, lora_model_id: str) -> Tuple[str, Optional[str]]:
        try:
            try:
                card = InferencePipeline.get_model_card(lora_model_id,
                                                        self.hf_token)
            except Exception:
                return '', None
            instance_prompt = getattr(card.data, 'instance_prompt', '')
            image_url = get_image_from_card(card, lora_model_id)
            return instance_prompt, image_url
        except Exception as e:
            raise type(e)(f'failed to load_model_info, due to: {e}')

    def update_model_info(self, model_source: str):
        try:
            if model_source == 'None':
                return '', None
            else:
                model_info = self.load_model_info(model_source)
                new_prompt, new_image = model_info[0], model_info[1]
            return new_prompt, new_image
        except Exception as e:
            raise type(e)(f'failed to update_model_info, due to: {e}')


# Shared Hugging Face token (None when HF_TOKEN is unset), used both by the
# inference pipeline and by the model-card helper that backs the UI callbacks.
hf_token = os.environ.get('HF_TOKEN')
pipe = InferencePipeline(hf_token)
app = InferenceUtil(hf_token)


# Gradio UI: a content column and a style column (each with a model dropdown,
# its instance prompt and preview image), followed by the combined prompt,
# generation parameters, the result gallery and the event wiring.
with gr.Blocks(css=css) as demo:
    title = gr.HTML(
        '''<h1>Implicit Style-Content Separation using B-LoRA</h1>
        <p>This is a demo for our <a href="https://arxiv.org/abs/2403.14572">paper</a>: <b>''Implicit Style-Content Separation using B-LoRA''</b>.
    <br>
    Project page and code are available <a href="https://b-lora.github.io/B-LoRA/">here</a>.</p>
    <p>Select your favorite style and content components from the list (prefixed with <strong>`B-LoRA-`</strong>).</p>
        ''',
        elem_id="title",
    )
    with gr.Row(elem_classes="gr-row"):
        # Content B-LoRA column.
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                content_sub_title = gr.HTML('''<h2>Content B-LoRA</h2>''', elem_classes="lora-title")
                content_checkbox = gr.Checkbox(label='Use Content Only', value=False)
                # Choices start empty; demo_init fills them on page load.
                content_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                content_prompt = gr.Text(label='Content instance prompt', interactive=False, max_lines=1)
                content_image = gr.Image(label='Content Image', elem_classes="gr-image")
        # Style B-LoRA column (mirrors the content column).
        with gr.Column():
            with gr.Group(elem_classes="lora-column"):
                style_sub_title = gr.HTML('''<h2>Style B-LoRA</h2>''', elem_classes="lora-title")
                style_checkbox = gr.Checkbox(label='Use Style Only', value=False)
                style_lora_model_id = gr.Dropdown(label='Model ID', choices=[])
                style_prompt = gr.Text(label='Style instance prompt', interactive=False, max_lines=1)
                style_image = gr.Image(label='Style Image', elem_classes="gr-image")
    with gr.Row(elem_classes="gr-row"):
        with gr.Column():
            with gr.Group():
                prompt = gr.Textbox(
                    label='Prompt',
                    max_lines=1,
                    placeholder='Example: "A [c] in [s] style"'
                )
                result = gr.Gallery(label='Result', elem_classes="res-image")
                with gr.Accordion('Other Parameters', open=False, elem_classes="gr-accordion"):
                    content_alpha = gr.Slider(label='Content B-LoRA alpha',
                                              minimum=0,
                                              maximum=2,
                                              step=0.05,
                                              value=1)
                    style_alpha = gr.Slider(label='Style B-LoRA alpha',
                                            minimum=0,
                                            maximum=2,
                                            step=0.05,
                                            value=1)
                    seed = gr.Slider(label='Seed',
                                     minimum=0,
                                     maximum=100000,
                                     step=1,
                                     value=8888)
                    # At least one inference step is required; 0 would ask the
                    # pipeline to generate without any denoising step.
                    num_steps = gr.Slider(label='Number of Steps',
                                          minimum=1,
                                          maximum=100,
                                          step=1,
                                          value=40)
                    guidance_scale = gr.Slider(label='CFG Scale',
                                               minimum=0,
                                               maximum=50,
                                               step=0.1,
                                               value=7.5)
                    num_images_per_prompt = gr.Slider(label='Number of Images per Prompt',
                                                      minimum=1,
                                                      maximum=4,
                                                      step=1,
                                                      value=2)
                run_button = gr.Button('Generate')

    # Page load: pick random content/style B-LoRAs and fill every field.
    demo.load(demo_init, inputs=[],
              outputs=[content_lora_model_id, content_prompt, content_image, style_lora_model_id, style_prompt,
                       style_image, prompt], queue=False, show_progress="hidden")
    # Selecting a model reloads that column's instance prompt and preview image.
    content_lora_model_id.change(
        fn=app.update_model_info,
        inputs=content_lora_model_id,
        outputs=[
            content_prompt,
            content_image,
        ])
    style_lora_model_id.change(
        fn=app.update_model_info,
        inputs=style_lora_model_id,
        outputs=[
            style_prompt,
            style_image,
        ])
    # Either instance prompt changing recomposes the generation prompt.
    style_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    content_prompt.change(
        fn=handle_prompt_change,
        inputs=[content_prompt, style_prompt],
        outputs=prompt,
    )
    # "Use X Only" sets the *other* dropdown to 'None' when checked and
    # re-randomises it when unchecked (see toggle_column).
    content_checkbox.change(toggle_column, inputs=[content_checkbox],
                            outputs=[style_lora_model_id])
    style_checkbox.change(toggle_column, inputs=[style_checkbox],
                          outputs=[content_lora_model_id])
    # Gradio passes these values positionally to pipe.run, so this order must
    # match its parameter order.
    inputs = [
        content_lora_model_id,
        style_lora_model_id,
        prompt,
        content_alpha,
        style_alpha,
        seed,
        num_steps,
        guidance_scale,
        num_images_per_prompt,
    ]
    prompt.submit(fn=pipe.run, inputs=inputs, outputs=result)
    run_button.click(fn=pipe.run, inputs=inputs, outputs=result)

demo.queue(max_size=10).launch(share=False)