import json
import os
import cv2
import random
import numpy as np
import gradio as gr
import torch

from zhipuai import ZhipuAI
from pytorch_lightning import seed_everything
from pprint import pprint
from PIL import Image, ImageDraw, ImageFont
from diffusers import (
    ControlNetModel,
    StableDiffusionControlNetPipeline,
)
from diffusers import (
    DDIMScheduler,
    PNDMScheduler,
    EulerAncestralDiscreteScheduler,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
    LMSDiscreteScheduler,
    HeunDiscreteScheduler
)
from controlnet_aux import (
    PidiNetDetector,
    HEDdetector
)


# Maximum number of text bounding boxes the UI exposes.
BBOX_MAX_NUM = 8
# Number of boxes enabled when the UI first loads.
BBOX_INI_NUM = 0
# Maximum characters rendered per text box (longer text is cropped).
MAX_LENGTH = 20
# Device the diffusion pipeline runs on; assumes a CUDA GPU is present.
device = 'cuda'
# Lazily-built StableDiffusionControlNetPipeline, cached across calls (see process()).
pipeline = None
# Model name the cached `pipeline` was built for — the cache key.
pre_pipeline = None
# Root directory holding model weights; configured via the REPO_ROOT env var.
model_root = os.getenv('REPO_ROOT')
# Scheduler configs live in per-scheduler subfolders under this path.
scheduler_root = f'{model_root}/Scheduler'
# Display names shown in the base-model dropdown (keys of match_dict below).
model_list =[
    'JoyType.v1.0', 'RevAnimated-animation-动漫', 'GhostMix-animation-动漫',
    'rpg.v5-fantasy_realism-奇幻写实', 'midjourneyPapercut-origami-折纸版画',
    'dvarchExterior-architecture-建筑', 'awpainting.v13-portrait-人物肖像'
]
# Canned Chinese prompts mapped to English equivalents, bypassing the GLM
# translation step in process().
chn_example_dict = {
    '漂亮的风景照,很多山峰,清澈的湖水': 'beautiful landscape, many peaks, clear lake',
    '画有玫瑰的卡片,明亮的背景': 'a card with roses, bright background',
    '一张关于健康教育的卡片,上面有一些文字,有一些食物图标,背景里有一些水果喝饮料的图标,且背景是模糊的': \
        'a card for health education, with some writings on it, '
        'food icons on the card, some fruits and drinking in the background, blur background '
}
# Maps dropdown display names to on-disk model directory names under model_root.
match_dict = {
    'JoyType.v1.0': 'JoyType-v1-1M',
    'RevAnimated-animation-动漫': 'rev-animated-v1-2-2',
    'GhostMix-animation-动漫': 'GhostMix_V2.0',
    'rpg.v5-fantasy_realism-奇幻写实': 'rpg_v5',
    'midjourneyPapercut-origami-折纸版画': 'midjourneyPapercut_v1',
    'dvarchExterior-architecture-建筑': 'dvarchExterior',
    'awpainting.v13-portrait-人物肖像': 'awpainting_v13'
}
# Selectable fonts, named '<LANG>-<font file stem>'; the stem after the last
# '-' is looked up as ./font/<stem>.ttf or .otf in render_all_text().
font_list = [
    'CHN-华文行楷',
    'CHN-华文新魏',
    'CHN-清松手写体',
    'CHN-巴蜀墨迹',
    'CHN-雷盖体',
    'CHN-演示夏行楷',
    'CHN-鸿雷板书简体',
    'CHN-斑马字类',
    'CHN-青柳隶书',
    'CHN-辰宇落雁体',
    'CHN-宅家麦克笔',
    'ENG-Playwrite',
    'ENG-Okesip',
    'ENG-Shrikhand',
    'ENG-Nextstep',
    'ENG-Filthyrich',
    'ENG-BebasNeue',
    'ENG-Gloock',
    'ENG-Lemon',
    'RUS-Automatons',
    'RUS-MKyrill',
    'RUS-Alice',
    'RUS-Caveat',
    'KOR-ChosunGs',
    'KOR-Dongle',
    'KOR-GodoMaum',
    'KOR-UnDotum',
    'JPN-GlTsukiji',
    'JPN-Aoyagireisyosimo',
    'JPN-KouzanMouhitu',
    'JPN-Otomanopee'
]


def change_settings(base_model):
    """Return per-model default UI updates: (steps, cfg scale, scheduler name).

    Unknown model names yield no updates (None), mirroring the original
    fall-through behavior.
    """
    presets = {
        model_list[0]: (20, 7.5, 'PNDM'),
        model_list[1]: (30, 8.5, 'Euler'),
        model_list[2]: (32, 8.5, 'Euler'),
        model_list[3]: (20, 7.5, 'DPM'),
        model_list[4]: (25, 6.5, 'Euler'),
        model_list[5]: (25, 8.5, 'Euler'),
        model_list[6]: (25, 7, 'DPM'),
    }
    if base_model not in presets:
        return None
    steps, cfg, scheduler = presets[base_model]
    return gr.update(value=steps), gr.update(value=cfg), gr.update(value=scheduler)


def update_box_num(choice):
    """Show the first `choice` box widgets and hide/reset the remainder.

    Returns, flattened: BBOX_MAX_NUM checkbox updates, font-dropdown
    updates, text-field updates, then 4 position-slider updates per box.
    """
    checkbox_updates = []
    font_updates = []
    text_updates = []
    position_updates = []

    for idx in range(BBOX_MAX_NUM):
        if idx < choice:
            # Active box: tick its checkbox, reveal font/text widgets.
            checkbox_updates.append(gr.update(value=True))
            font_updates.append(gr.update(visible=True))
            text_updates.append(gr.update(visible=True))
            position_updates.extend(gr.update(visible=False) for _ in range(4))
        else:
            # Inactive box: untick, hide and reset widgets to defaults.
            checkbox_updates.append(gr.update(value=False))
            font_updates.append(gr.update(visible=False, value='CHN-华文行楷'))
            text_updates.append(gr.update(visible=False, value=''))
            for default in (0.4, 0.4, 0.2, 0.2):
                position_updates.append(gr.update(visible=False, value=default))

    return *checkbox_updates, *font_updates, *text_updates, *position_updates


def load_box_list(example_id, choice):
    """Restore box widget state from templates/{example_id}.json.

    Returns, flattened: checkbox, font, text and position updates for all
    BBOX_MAX_NUM boxes, plus a final update resetting the selector to -1.
    (`choice` is accepted for interface compatibility but unused.)
    """
    with open(f'templates/{example_id}.json', 'r') as fp:
        info = json.load(fp)

    checkbox_updates = []
    font_updates = []
    text_updates = []
    position_updates = []

    for idx in range(BBOX_MAX_NUM):
        shown = info['visible'][idx]
        checkbox_updates.append(gr.update(value=shown))
        font_updates.append(gr.update(value=info['font'][idx], visible=shown))
        text_updates.append(gr.update(value=info['text'][idx], visible=shown))
        position_updates.extend(
            gr.update(value=coord)
            for coord in info['pos'][idx * 4: (idx + 1) * 4]
        )

    return *checkbox_updates, *font_updates, \
        *text_updates, *position_updates, gr.update(value=-1)

    
def re_edit():
    """Reset every box position slider to defaults and recreate the canvas.

    Returns:
        4 * BBOX_MAX_NUM slider updates (x=0.4, y=0.4, w=0.2, h=0.2),
        a fresh canvas image component, and width/height sliders reset
        to 512.
    """
    # NOTE: the original declared `global BBOX_MAX_NUM`, but the constant is
    # only read here, so the global statement was unnecessary and removed.
    update_list = []
    for _ in range(BBOX_MAX_NUM):
        update_list.extend(gr.update(value=v) for v in (0.4, 0.4, 0.2, 0.2))
    return *update_list, \
        gr.Image(
            value=create_canvas(),
            label='Rect Position', elem_id='MD-bbox-rect-t2i',
            show_label=False, visible=True
        ), \
        gr.Slider(value=512), gr.Slider(value=512)


def resize_w(w, img):
    """Resize `img` to width `w`, keeping its current height unchanged."""
    current_height = img.shape[0]
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (w, current_height))


def resize_h(h, img):
    """Resize `img` to height `h`, keeping its current width unchanged."""
    current_width = img.shape[1]
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(img, (current_width, h))


def create_canvas(w=512, h=512, c=3, line=5):
    """Create a light-gray canvas with grid lines and a red center marker.

    Args:
        w: canvas width in pixels.
        h: canvas height in pixels.
        c: number of channels.
        line: number of grid cells per axis.

    Returns:
        np.ndarray of shape (h, w, c), dtype uint8.
    """
    image = np.full((h, w, c), 200, dtype=np.uint8)
    # Horizontal grid lines. BUGFIX: the spacing must come from the canvas
    # height; the original used `w // line` here, which misplaced rows on
    # non-square canvases (identical behavior when w == h).
    for i in range(h):
        if i % (h // line) == 0:
            image[i, :, :] = 150
    # Vertical grid lines, spaced by the canvas width.
    for j in range(w):
        if j % (w // line) == 0:
            image[:, j, :] = 150
    # 16x16 red marker at the canvas center.
    image[h // 2 - 8:h // 2 + 8, w // 2 - 8:w // 2 + 8, :] = [200, 0, 0]
    return image


def canny(img):
    """Return the Canny edge map of `img` as a 3-channel PIL image."""
    low_threshold = 64
    high_threshold = 100

    edges = cv2.Canny(img, low_threshold, high_threshold)
    # Replicate the single-channel edge map across R, G and B.
    rgb_edges = np.stack([edges, edges, edges], axis=-1)
    return Image.fromarray(rgb_edges)


def judge_overlap(coord_list1, coord_list2):
    """Return True when two axis-aligned boxes strictly overlap.

    Each box is given as [x1, y1, x2, y2]; boxes that merely touch along
    an edge do not count as overlapping.
    """
    a, b = coord_list1, coord_list2
    x_overlap = a[0] < b[2] and b[0] < a[2]
    y_overlap = a[1] < b[3] and b[1] < a[3]
    return x_overlap and y_overlap


def parse_render_list(box_list, shape, box_num):
    """Validate box widget state and build the render list.

    Args:
        box_list: flattened widget values — `box_num` visibility flags,
            then 4 * box_num fractional positions (x, y, w, h), box_num
            font names, and box_num text strings.
        shape: (width, height) of the target image in pixels.
        box_num: number of boxes encoded in `box_list`.

    Returns:
        (render_list, ok): a list of {'text', 'polygon', 'font_name'}
        dicts and True on success, or ([], False) after emitting a
        gr.Warning on validation failure.
    """
    width = shape[0]
    height = shape[1]
    polygons = []
    font_names = []
    texts = []
    valid_list = box_list[:box_num]
    pos_list = box_list[box_num: 5 * box_num]
    font_name_list = box_list[5 * box_num: 6 * box_num]
    text_list = box_list[6 * box_num: 7 * box_num]
    empty_flag = False
    # (removed leftover debug print of font_name_list / text_list)

    for i, valid in enumerate(valid_list):
        if not valid:
            continue
        # Fractional positions scaled to pixel coordinates.
        pos = pos_list[i * 4: (i + 1) * 4]
        top_left_x = int(pos[0] * width)
        top_left_y = int(pos[1] * height)
        w = int(pos[2] * width)
        h = int(pos[3] * height)
        font_name = str(font_name_list[i])
        text = str(text_list[i])
        if text == '':
            # Empty boxes get placeholder text; warn once at the end.
            empty_flag = True
            text = 'JoyType'
        if w <= 0 or h <= 0:
            gr.Warning(f'Area of the box{i + 1} cannot be zero!')
            return [], False
        # Explicit membership check instead of the original try/assert:
        # `assert` is stripped under `python -O`.
        if font_name not in font_list:
            gr.Warning('Please choose a correct font!')
            return [], False

        polygons.append([top_left_x, top_left_y, w, h])
        # Keep only the file-name part after the language prefix,
        # e.g. 'ENG-Gloock' -> 'Gloock'.
        font_names.append(font_name.split('-')[-1])
        texts.append(text)

    if empty_flag:
        gr.Warning('Null strings will be filled automatically!')

    # Reject layouts with any pair of overlapping boxes.
    for i in range(len(polygons)):
        for j in range(i + 1, len(polygons)):
            if judge_overlap(
                    [polygons[i][0], polygons[i][1], polygons[i][0] + polygons[i][2], polygons[i][1] + polygons[i][3]],
                    [polygons[j][0], polygons[j][1], polygons[j][0] + polygons[j][2], polygons[j][1] + polygons[j][3]]
            ):
                gr.Warning('Find overlapping boxes!')
                return [], False

    render_list = [
        {'text': texts[i], 'polygon': polygons[i], 'font_name': font_names[i]}
        for i in range(len(polygons))
    ]
    return render_list, True


def render_all_text(render_list, shape, threshold=512):
    """Render every entry of `render_list` as white text on a black board.

    Args:
        render_list: dicts with 'text', 'polygon' ([x, y, w, h] in pixels)
            and 'font_name' (font file stem under ./font/).
        shape: (width, height) of the output board.
        threshold: unused; kept for backward compatibility.

    Returns:
        PIL.Image.Image: the rendered board.
    """
    width = shape[0]
    height = shape[1]
    board = Image.new('RGB', (width, height), 'black')

    for text_dict in render_list:
        text = text_dict['text']
        polygon = text_dict['polygon']
        font_name = text_dict['font_name']
        if len(text) > MAX_LENGTH:
            text = text[:MAX_LENGTH]
            gr.Warning(f'{text}... exceeds the maximum length {MAX_LENGTH} and has been cropped.')

        w, h = polygon[2:]
        # Boxes taller than wide are rendered vertically, one char per line.
        vert = w < h
        # Oversized scratch image used only to measure/draw the raw text.
        image4ratio = Image.new('RGB', (1024, 1024), 'black')
        draw = ImageDraw.Draw(image4ratio)

        # Fonts may ship as .ttf or .otf; try .ttf first.
        try:
            font = ImageFont.truetype(f'./font/{font_name}.ttf', encoding='utf-8', size=50)
        except FileNotFoundError:
            font = ImageFont.truetype(f'./font/{font_name}.otf', encoding='utf-8', size=50)

        if not vert:
            draw.text(xy=(0, 0), text=text, font=font, fill='white')
            _, _, _tw, _th = draw.textbbox(xy=(0, 0), text=text, font=font)
            _th += 1
        else:
            # Stack characters vertically; track the widest glyph.
            _tw, y_c = 0, 0
            for c in text:
                draw.text(xy=(0, y_c), text=c, font=font, fill='white')
                _l, _t, _r, _b = font.getbbox(c)
                _tw = max(_tw, _r - _l)
                y_c += _b
            _th = y_c + 1

        # Ratio of the text's aspect to the target box's aspect.
        ratio = (_th * w) / (_tw * h)
        text_img = image4ratio.crop((0, 0, _tw, _th))
        x_offset, y_offset = 0, 0
        if 0.8 <= ratio <= 1.2:
            # Close enough: stretch to fill the whole box.
            text_img = text_img.resize((w, h))
        elif ratio < 0.75:
            # Text relatively wider than the box: fit width, center vertically.
            # Renamed from `resize_h`, which shadowed the module-level function.
            fit_h = int(_th * (w / _tw))
            text_img = text_img.resize((w, fit_h))
            y_offset = (h - fit_h) // 2
        else:
            # Text relatively taller than the box: fit height, center horizontally.
            # Renamed from `resize_w`, which shadowed the module-level function.
            fit_w = int(_tw * (h / _th))
            text_img = text_img.resize((fit_w, h))
            x_offset = (w - fit_w) // 2

        board.paste(text_img, (polygon[0] + x_offset, polygon[1] + y_offset))

    return board


def load_pipeline(model_name, scheduler_name):
    """Build a StableDiffusionControlNetPipeline for `model_name`.

    Args:
        model_name: subdirectory of `model_root` holding the base model.
        scheduler_name: scheduler key (case-insensitive): 'PNDM', 'LMS',
            'Euler', 'DPM', 'DDIM', 'Heun' or 'Euler-Ancestral'.

    Returns:
        The pipeline, moved to `device`.

    Raises:
        ValueError: for an unknown scheduler name.  (The original if-chain
            left `scheduler` unbound in that case, raising an opaque
            NameError instead.)
    """
    controlnet_path = os.path.join(model_root, f'{match_dict["JoyType.v1.0"]}')
    model_path = os.path.join(model_root, model_name)
    scheduler_name = scheduler_name.lower()

    # Each scheduler's config lives in a subfolder named after its key.
    scheduler_classes = {
        'pndm': PNDMScheduler,
        'lms': LMSDiscreteScheduler,
        'euler': EulerDiscreteScheduler,
        'dpm': DPMSolverMultistepScheduler,
        'ddim': DDIMScheduler,
        'heun': HeunDiscreteScheduler,
        'euler-ancestral': EulerAncestralDiscreteScheduler,
    }
    try:
        scheduler_cls = scheduler_classes[scheduler_name]
    except KeyError:
        raise ValueError(f'Unknown scheduler: {scheduler_name!r}') from None
    scheduler = scheduler_cls.from_pretrained(scheduler_root, subfolder=scheduler_name)

    controlnet = ControlNetModel.from_pretrained(
        controlnet_path,
        subfolder='controlnet',
        torch_dtype=torch.float32
    )
    pipeline = StableDiffusionControlNetPipeline.from_pretrained(
        model_path,
        scheduler=scheduler,
        controlnet=controlnet,
        torch_dtype=torch.float32,
    ).to(device)

    return pipeline


def preprocess_prompt(prompt):
    """Translate/expand a user prompt into an SD-style prompt via the GLM API.

    The (Chinese) system message instructs GLM-4 to produce an English
    Stable-Diffusion prompt.  Requires the ZHIPU_API_KEY environment
    variable to be set.

    Args:
        prompt: raw user prompt (typically containing Chinese text).

    Returns:
        {'flag': 1, 'data': [reply_text, ...]} on success, or
        {'flag': 0, 'data': {}} when the API returned a falsy response.
    """
    client = ZhipuAI(api_key=os.getenv('ZHIPU_API_KEY'))
    response = client.chat.completions.create(
        model="glm-4-0520",
        messages=[
            {
                'role': 'system',
                'content': '''
                   Stable Diffusion是一款利用深度学习的文生图模型,支持通过使用提示词来产生新的图像,描述要包含或省略的元素。
                   我在这里引入Stable Diffusion算法中的Prompt概念,又被称为提示符。这里的Prompt通常可以用来描述图像,
                   他由普通常见的单词构成,最好是可以在数据集来源站点找到的著名标签(比如Ddanbooru)。
                   下面我将说明Prompt的生出步骤,这里的Prompt主要用于描述人物。在Prompt的生成中,你需要通过提示词来描述 人物属性,主题,外表,情绪,衣服,姿势,视角,动作,背景。
                   用英语单词或短语甚至自然语言的标签来描述,并不局限于我给你的单词。然后将你想要的相似的提示词组合在一起,请使用英文半角,做分隔符,每个提示词不要带引号,并将这些按从最重要到最不重要的顺序 排列。
                   另外请您注意,永远在每个 Prompt的前面加上引号里的内容, 
                   “(((best quality))),(((ultra detailed))),(((masterpiece))),illustration,” 这是高质量的标志。
                   人物属性中,1girl表示你生成了一个女孩,2girls表示生成了两个女孩,一次。另外再注意,Prompt中不能带有-和_。
                   可以有空格和自然语言,但不要太多,单词不能重复。只返回Prompt。
                   '''
            },
            {
                'role': 'user', 
                'content': prompt
            }
        ],
        temperature=0.5,
        max_tokens=2048,
        top_p=1,
        stream=False,
    )

    if response:
        # Collect the text content of every returned choice.
        glm = []
        glm_return_list = response.choices
        for item in glm_return_list:
            glm.append(item.message.content)

        return {'flag': 1, 'data': glm}
    else:
        return {'flag': 0, 'data': {}}


def process(
        num_samples,
        a_prompt,
        n_prompt,
        conditioning_scale,
        cfg_scale,
        inference_steps,
        seed,
        usr_prompt,
        rect_img,
        base_model,
        scheduler_name,
        box_num,
        *box_list
):
    """Main generation callback: render the text layout, then run ControlNet.

    Args:
        num_samples: number of images to generate in one batch.
        a_prompt: additional (quality) tags appended to the user prompt.
        n_prompt: negative prompt.
        conditioning_scale: ControlNet conditioning strength.
        cfg_scale: classifier-free guidance scale.
        inference_steps: number of diffusion steps.
        seed: RNG seed; -1 selects a random seed.
        usr_prompt: user prompt; Chinese input is translated first.
        rect_img: the canvas image; only its shape is used here.
        base_model: display name of the base model (key into match_dict).
        scheduler_name: scheduler selection forwarded to load_pipeline.
        box_num: number of active text boxes.
        *box_list: flattened box widget state (flags, positions, fonts, texts).

    Returns:
        (images, markdown) on success; (None, gr.Markdown('error')) on failure.
    """
    if usr_prompt == '':
        gr.Warning('Must input a prompt!')
        return None, gr.Markdown('error')

    if seed == -1:
        seed = random.randint(0, 2147483647)
    seed_everything(seed)

    # Support Chinese input: canned examples map directly to English; any
    # other prompt containing a CJK character is sent to the GLM API.
    if usr_prompt in chn_example_dict.keys():
        usr_prompt = chn_example_dict[usr_prompt]
    else:
        for ch in usr_prompt:
            if '\u4e00' <= ch <= '\u9fff':
                data = preprocess_prompt(usr_prompt)
                if data['flag'] == 1:
                    # Drops the first and last character of the reply —
                    # presumably surrounding quotes; TODO confirm against
                    # actual GLM output.
                    usr_prompt = data['data'][0][1: -1]
                else:
                    gr.Warning('Something went wrong while translating your prompt, please try again.')
                    return None, gr.Markdown('error')
                break

    # (width, height) of the output image, taken from the canvas.
    shape = (rect_img.shape[1], rect_img.shape[0])
    render_list, flag = parse_render_list(box_list, shape, box_num)
    if flag:
        render_img = render_all_text(render_list, shape)
    else:
        return None, gr.Markdown('error')

    model_name = match_dict[base_model]
    # The Canny edge map of the rendered text is the ControlNet condition.
    render_img = canny(np.array(render_img))

    w, h = render_img.size
    global pipeline, pre_pipeline

    # Reuse the cached pipeline unless a different base model was requested.
    if pre_pipeline != model_name or pipeline is None:
        pre_pipeline = model_name
        pipeline = load_pipeline(model_name, scheduler_name)

    batch_render_img = [render_img for _ in range(num_samples)]
    batch_prompt = [f'{usr_prompt}, {a_prompt}' for _ in range(num_samples)]
    batch_n_prompt = [n_prompt for _ in range(num_samples)]

    images = pipeline(
        batch_prompt,
        negative_prompt=batch_n_prompt,
        image=batch_render_img,
        controlnet_conditioning_scale=float(conditioning_scale),
        guidance_scale=float(cfg_scale),
        width=w,
        height=h,
        num_inference_steps=int(inference_steps),
    ).images

    return images, gr.Markdown(f'{seed},  {usr_prompt},  {box_list}')


def draw_example(box_list, color, id):
    """Persist the current box layout as a template and render a preview.

    Writes the widget state to templates/{id}.json and saves a canvas
    preview with colored rectangles to ./examples/{id}.png.

    Args:
        box_list: flattened widget state — BBOX_MAX_NUM visibility flags,
            then 4 * BBOX_MAX_NUM positions, fonts, and texts.
        color: per-box (outline, fill) color pairs.
        id: template identifier.  (The name shadows the builtin `id` but is
            kept for interface compatibility with existing callers.)
    """
    board = Image.fromarray(create_canvas())

    w, h = board.size

    draw = ImageDraw.Draw(board, mode='RGBA')
    visible = box_list[:BBOX_MAX_NUM]
    pos = box_list[BBOX_MAX_NUM: 5 * BBOX_MAX_NUM]
    font = box_list[5 * BBOX_MAX_NUM: 6 * BBOX_MAX_NUM]
    text = box_list[6 * BBOX_MAX_NUM:]

    info = {
        'visible': list(visible),
        'pos': list(pos),
        'font': list(font),
        'text': list(text)
    }

    with open(f'templates/{id}.json', 'w') as f:
        json.dump(info, f)

    # (removed leftover debug print of each polygon)
    for i in range(BBOX_MAX_NUM):
        # Only the exact boolean True draws a box (checkbox state).
        if visible[i] is True:
            # Positions are fractions of the board size.
            polygon = pos[i * 4: (i + 1) * 4]
            left = w * polygon[0]
            top = h * polygon[1]
            right = left + w * polygon[2]
            bottom = top + h * polygon[3]
            draw.rectangle([left, top, right, bottom], outline=color[i][0], fill=color[i][1], width=3)

    board.save(f'./examples/{id}.png')


if __name__ == '__main__':
    # Intentionally empty entry point — the Gradio UI that wires these
    # callbacks together presumably lives in another module; TODO confirm.
    pass