import os
import random
import sys
sys.path.append(os.getcwd())
sys.path.append(os.path.join(os.getcwd(), "annotator/entityseg"))
import cv2
import spaces
import einops
import torch
import gradio as gr
import numpy as np
from pytorch_lightning import seed_everything
from PIL import Image

from annotator.util import resize_image, HWC3
from annotator.canny import CannyDetector
from annotator.midas import MidasDetector
from annotator.entityseg import EntitysegDetector
from annotator.openpose import OpenposeDetector
from annotator.content import ContentDetector
from annotator.cielab import CIELabDetector

from models.util import create_model, load_state_dict
from models.ddim_hacked import DDIMSampler

'''
define conditions
'''
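# up to `max_conditions` condition images can be uploaded; edge/depth/seg/pose are
# spatial (local) controls, content/color are global controls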
max_conditions = 8
condition_types = ["edge", "depth", "seg", "pose", "content", "color"]

apply_canny = CannyDetector()
apply_midas = MidasDetector()
apply_seg = EntitysegDetector()
apply_openpose = OpenposeDetector()
apply_content = ContentDetector()
apply_color = CIELabDetector()

processors = {
    "edge": apply_canny,
    "depth": apply_midas,
    "seg": apply_seg,
    "pose": apply_openpose,
    "content": apply_content,
    "color": apply_color,
}

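# task names appended to the text prompt for each spatial condition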
descriptors = {
    "edge": "canny",
    "depth": "depth",
    "seg": "segmentation",
    "pose": "openpose",
}


@torch.no_grad()
def get_unconditional_global(c_global):
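    """Return an all-zero copy of a global conditioning, used as the
    unconditional branch for classifier-free guidance."""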
    if isinstance(c_global, dict):
        return {k:torch.zeros_like(v) for k,v in c_global.items()}
    elif isinstance(c_global, list):
        return [torch.zeros_like(c) for c in c_global]
    else:
        return torch.zeros_like(c_global) 


@spaces.GPU(duration=120)
def process(prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps, 
            strength, scale, seed, eta, global_strength, color_strength, local_strength, *args):
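    """Generate images with AnyControl.

    `*args` carries the dynamic UI inputs as alternating (condition image,
    condition type) pairs coming from the condition rows built in the UI.
    Returns a list of generated images and a list of the condition maps shown
    in the condition gallery.
    """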
    
    # seed == -1 means "pick a random seed"
    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)

    # keep only rows where both the condition image and its type were provided
    pairs = [(c, t) for c, t in zip(args[0::2], args[1::2]) if c is not None and t is not None]
    conds = [c for c, _ in pairs]
    types = [t for _, t in pairs]

    detected_maps = []
    other_maps = []
    tasks = []

    # global control: CLIP content embedding (1x768) and CIELab color histogram (1x180),
    # zero-initialized here and converted to batched CUDA tensors after the
    # uploaded conditions have been parsed
    global_conditions = dict(clipembedding=np.zeros((1, 768), dtype=np.float32), color=np.zeros((1, 180), dtype=np.float32))
    global_control = {}

    # initialize local control
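    # a blank anchor image keeps generation working when no spatial condition
    # (edge/depth/seg/pose) is uploaded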
    anchor_image = HWC3(np.zeros((image_resolution, image_resolution, 3)).astype(np.uint8))
    oH, oW = anchor_image.shape[:2]
    H, W, C = resize_image(anchor_image, image_resolution).shape
    anchor_tensor = ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(anchor_image))
    local_control = torch.as_tensor(anchor_tensor).cuda().to(memory_format=torch.contiguous_format).half()

    task_prompt = ''

    with torch.no_grad():

        # set up local control
        for cond, typ in zip(conds, types):
            if typ in ['edge', 'depth', 'seg', 'pose']:
                oH, oW = cond.shape[:2] 
                cond_image = HWC3(cv2.resize(cond, (W, H)))
                cond_detected_map = processors[typ](cond_image)
                cond_detected_map = HWC3(cond_detected_map)
                detected_maps.append(cond_detected_map)
                tasks.append(descriptors[typ])
            elif typ in ['content']:
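                # global content control: CLIP image embedding of the reference image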
                other_maps.append(cond)
                content_image = cv2.cvtColor(cond, cv2.COLOR_RGB2BGR)
                content_emb = apply_content(content_image)
                global_conditions['clipembedding'] = content_emb
            elif typ in ['color']:
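                # global color control: CIELab color histogram; the palette image is only for display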
                color_hist = apply_color(cond)
                global_conditions['color'] = color_hist
                color_palette = apply_color.hist_to_palette(color_hist)  # (50, 189, 3)
                color_palette = cv2.resize(color_palette, (W, H), interpolation=cv2.INTER_NEAREST)
                other_maps.append(color_palette)
        if len(detected_maps) > 0:
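            # stack the spatial condition maps along the channel dimension after
            # preprocessing each map with the Q-Former visual processor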
            local_control = torch.cat([ddim_sampler.model.qformer_vis_processor['eval'](Image.fromarray(img)).cuda().unsqueeze(0) for img in detected_maps], dim=1)
            task_prompt = ' conditioned on ' + ' and '.join(tasks)
        local_control = local_control.repeat(num_samples, 1, 1, 1)

        # set up global control
        for key in global_conditions.keys():
            global_cond = torch.from_numpy(global_conditions[key]).unsqueeze(0).repeat(num_samples, 1, 1)
            global_cond = global_cond.cuda().to(memory_format=torch.contiguous_format).float()
            global_control[key] = global_cond

        # set up prompt
        input_prompt = (prompt + ' ' + task_prompt).strip()

        # set up cfg
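        # classifier-free guidance: reuse the local control for the unconditional
        # branch and zero out the global control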
        uc_local_control = local_control
        uc_global_control = get_unconditional_global(global_control)

        cond = {
            "local_control": [local_control], 
            "c_crossattn": [model.get_learned_conditioning([prompt + ', ' + a_prompt] * num_samples)], 
            "global_control": [global_control],
            "text": [[input_prompt] * num_samples],
        }
        un_cond = {
            "local_control": [uc_local_control], 
            "c_crossattn": [model.get_learned_conditioning([n_prompt] * num_samples)], 
            'global_control': [uc_global_control],
            "text": [[input_prompt] * num_samples],
        }
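        # latent shape: 4 channels at 1/8 of the output resolution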
        shape = (4, H // 8, W // 8)

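        # one control scale per control block (13 in total)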
        model.control_scales = [strength] * 13
        samples, _ = ddim_sampler.sample(ddim_steps, num_samples,
                                                     shape, cond, verbose=False, eta=eta,
                                                     unconditional_guidance_scale=scale,
                                                     unconditional_conditioning=un_cond, 
                                                     global_strength=global_strength, 
                                                     color_strength=color_strength,
                                                     local_strength=local_strength)

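        # decode latents and map from [-1, 1] to [0, 255] uint8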
        x_samples = model.decode_first_stage(samples)
        x_samples = (einops.rearrange(x_samples, 'b c h w -> b h w c') * 127.5 + 127.5).cpu().numpy().clip(0, 255).astype(np.uint8)
        results = [x_samples[i] for i in range(num_samples)]

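    # resize outputs back to the size of the last spatial condition (or the anchor image)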
    results = [cv2.resize(res, (oW, oH)) for res in results]
    detected_maps = [cv2.resize(maps, (oW, oH)) for maps in detected_maps]
    return [results, detected_maps+other_maps]

  
def variable_image_outputs(k):
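    """Return visibility updates for the condition rows: show the first k, hide the rest."""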
    if k is None:
        k = 1
    k = int(k)
    imageboxes = []
    for i in range(max_conditions):
        visible = i < k
        img = gr.Image(sources=['upload'], type="numpy", label=f'Condition {i+1}',
                       visible=visible, interactive=True, scale=3, height=200)
        typ = gr.Dropdown(condition_types, visible=visible, interactive=True, label="type", scale=1)
        imageboxes.append(img)
        imageboxes.append(typ)
    return imageboxes


'''
define model
'''
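# load the AnyControl model on CPU, then move it to GPU and wrap it in a DDIM sampler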
config_file = "configs/anycontrol.yaml"
model_file = "ckpts/anycontrol_15.ckpt"
model = create_model(config_file).cpu()
model.load_state_dict(load_state_dict(model_file, location='cpu'))
model = model.cuda()
ddim_sampler = DDIMSampler(model)



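# build the Gradio UI: condition uploads on the left, prompt, options and results on the right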
block = gr.Blocks(theme='bethecloud/storj_theme').queue()
with block:
    with gr.Row():
        gr.Markdown("## AnyControl Demo")
        gr.Markdown("---")
    with gr.Row():
        with gr.Column(scale=1):
            with gr.Group():
                s = gr.Slider(1, max_conditions, value=1, step=1, label="How many conditions to upload:")
                imageboxes = []
                for i in range(max_conditions):
                    if i==0:
                        with gr.Row():
                            img = gr.Image(visible=True, sources=['upload'], type="numpy", label='Condition 1', interactive=True, scale=3, height=200)
                            typ = gr.Dropdown(condition_types, visible=True, interactive=True, label="type", scale=1)
                    else:
                        with gr.Row():
                            img = gr.Image(visible=False, sources=['upload'], type="numpy", label=f'Condition {i+1}', scale=3, height=200)
                            typ = gr.Dropdown(condition_types, visible=False, interactive=True, label="type", scale=1)
                    imageboxes.append(img)
                    imageboxes.append(typ)
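                # the slider toggles visibility of the condition rows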
                s.change(variable_image_outputs, s, imageboxes)
        with gr.Column(scale=2):
            with gr.Row():
                prompt = gr.Textbox(label="Prompt")
            with gr.Row():
                with gr.Column():
                    with gr.Accordion("Advanced options", open=False):
                        num_samples = gr.Slider(label="Images", minimum=1, maximum=12, value=4, step=1)
                        image_resolution = gr.Slider(label="Image Resolution", minimum=256, maximum=768, value=512, step=64)
                        strength = gr.Slider(label="Control Strength", minimum=0.0, maximum=2.0, value=1, step=0.01)

                        local_strength = gr.Slider(label="Local Strength", minimum=0, maximum=2, value=1, step=0.01)
                        global_strength = gr.Slider(label="Global Strength", minimum=0, maximum=2, value=1, step=0.01)
                        color_strength = gr.Slider(label="Color Strength", minimum=0, maximum=2, value=1, step=0.01)
                        
                        ddim_steps = gr.Slider(label="Steps", minimum=1, maximum=100, value=50, step=1)
                        scale = gr.Slider(label="Guidance Scale", minimum=0.1, maximum=30.0, value=7.5, step=0.1)
                        seed = gr.Slider(label="Seed", minimum=-1, maximum=2147483647, value=42, step=1)
                        eta = gr.Number(label="Eta (DDIM)", value=0.0)
                        
                        a_prompt = gr.Textbox(label="Added Prompt", value='best quality, extremely detailed')
                        n_prompt = gr.Textbox(label="Negative Prompt",
                                              value='longbody, lowres, bad anatomy, bad hands, missing fingers, extra digit, fewer digits, cropped, worst quality, low quality')


            with gr.Row():
                run_button = gr.Button(value="Run")
            with gr.Row():
                image_gallery = gr.Gallery(label='Generation', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto', interactive=False)
            with gr.Row():
                cond_gallery = gr.Gallery(label='Condition', show_label=True, elem_id="gallery", columns=[4], rows=[1], height='auto', interactive=False)

            inputs = [prompt, a_prompt, n_prompt, num_samples, image_resolution, ddim_steps,
                      strength, scale, seed, eta, global_strength, color_strength, local_strength] + imageboxes
            run_button.click(fn=process, inputs=inputs, outputs=[image_gallery, cond_gallery])


# uncomment this block if you need a proxy or a custom temp directory
# os.environ['http_proxy'] = ''
# os.environ['https_proxy'] = ''
# os.environ['no_proxy'] = 'localhost,127.0.0.0/8,127.0.1.1'
# os.environ['HTTP_PROXY'] = ''
# os.environ['HTTPS_PROXY'] = ''
# os.environ['NO_PROXY'] = 'localhost,127.0.0.0/8,127.0.1.1'
# os.environ['TMPDIR'] = './tmpfiles'


block.launch(server_name='0.0.0.0', allowed_paths=["."], share=False)