import concurrent.futures
import io
import os
import time

import numpy as np
import oss2
import requests
from PIL import Image, ImageDraw, ImageFont

from .log import logger

# OSS object-storage configuration (via the `oss2` SDK).
# Credentials and bucket location come from environment variables read at import time;
# os.getenv returns None for any variable that is not set.
access_key_id = os.getenv("ACCESS_KEY_ID")
access_key_secret = os.getenv("ACCESS_KEY_SECRET")
bucket_name = os.getenv("BUCKET_NAME")
endpoint = os.getenv("ENDPOINT")

# Module-level bucket client used by the upload helpers in this module.
# NOTE(review): built at import time even when the env vars are missing — confirm that is intended.
bucket = oss2.Bucket(oss2.Auth(access_key_id, access_key_secret), endpoint, bucket_name)
# Object-key prefixes inside the bucket: default uploads vs. gallery uploads.
oss_path = "hejunjie.hjj/TransferAnythingHF"
oss_path_img_gallery = "hejunjie.hjj/TransferAnythingHF_img_gallery"

    
def download_img_pil(index, img_url, timeout=60):
    """Download one image over HTTP and open it with PIL.

    Args:
        index: position of the URL in the caller's list; echoed back so results
            can be matched up when downloads finish out of order.
        img_url: URL of the image.
        timeout: seconds passed to ``requests.get``. Without a timeout a
            stalled server would block the worker thread indefinitely.

    Returns:
        ``(index, img)`` on HTTP 200. Otherwise the failure is logged and
        ``(index, None)`` is returned. Always returning a tuple keeps callers
        that unpack the result from crashing on a failed download.
    """
    # The whole body is read via .content anyway, so streaming buys nothing;
    # the context manager releases the connection back to the pool.
    with requests.get(img_url, timeout=timeout) as r:
        if r.status_code != 200:
            logger.error(f"Fail to download: {img_url}")
            return (index, None)
        # BytesIO owns its own copy of the bytes, so the lazily-decoding
        # PIL image stays valid after the response is closed.
        img = Image.open(io.BytesIO(r.content))
    return (index, img)


def download_images(img_urls, batch_size, max_workers=4):
    """Download a batch of images concurrently, keeping them in URL order.

    Order matters: later steps use the position to fetch the images or SVGs
    associated with each URL.

    Args:
        img_urls: iterable of image URLs.
        batch_size: minimum length of the returned list; slots without a
            URL stay ``None``.
        max_workers: number of download threads.

    Returns:
        A list of length ``max(batch_size, len(img_urls))``. Entry ``i`` holds
        the PIL image for ``img_urls[i]``, or ``None`` if that download failed
        or there was no URL for that slot.
    """
    urls = list(img_urls)
    # Never smaller than the number of URLs, so extra URLs cannot IndexError.
    imgs_pil = [None] * max(batch_size, len(urls))
    if not urls:
        return imgs_pil

    with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Executor.map yields results in submission order, so the enumerate
        # position is the URL index regardless of completion order.
        results = executor.map(download_img_pil, range(len(urls)), urls)
        for index, ret in enumerate(results):
            # The downloader may return None (or an (index, None) tuple) on
            # failure; leave that slot as None instead of crashing on unpack.
            if ret is not None:
                imgs_pil[index] = ret[1]
    return imgs_pil


def upload_np_2_oss(input_image, name="cache.png", gallery=False):
    """Encode a numpy image, upload it to OSS and return a signed download URL.

    Args:
        input_image: array accepted by ``PIL.Image.fromarray``
            (presumably HxW or HxWxC uint8 — confirm with callers).
        name: object name under the chosen prefix. The extension picks the
            encoding: ``.png`` gives lossless PNG and ``.jpg`` gives JPEG at
            quality 95 (case-insensitive).
        gallery: if True upload under ``oss_path_img_gallery``, else ``oss_path``.

    Returns:
        A signed GET URL for the uploaded object, valid for 24 hours.

    Raises:
        ValueError: if ``name`` does not end with ``.png`` or ``.jpg``.
            This is a real exception rather than ``assert`` so the check
            survives ``python -O``.
    """
    lower_name = name.lower()
    if not lower_name.endswith((".png", ".jpg")):
        raise ValueError(name)

    pil_image = Image.fromarray(input_image)
    buffer = io.BytesIO()
    if lower_name.endswith(".png"):
        pil_image.save(buffer, format="PNG")
    else:
        # JPEG cannot store alpha; Pillow refuses to write e.g. RGBA/LA as JPEG.
        if pil_image.mode not in ("RGB", "L"):
            pil_image = pil_image.convert("RGB")
        pil_image.save(buffer, format="JPEG", quality=95)
    data = buffer.getvalue()

    path = oss_path_img_gallery if gallery else oss_path
    key = path + "/" + name

    start_time = time.perf_counter()
    bucket.put_object(key, data)
    # Signed URL: HTTP method, object key, expiry in seconds (24 h).
    ret = bucket.sign_url('GET', key, 60 * 60 * 24)
    logger.info(f"upload cost: {time.perf_counter() - start_time} s.")
    return ret


def upload_json_string_2_oss(jsonStr, name="cache.txt", gallery=False):
    """Upload a string (e.g. serialized JSON) to OSS as UTF-8 bytes.

    Args:
        jsonStr: text to upload.
        name: object name under the chosen prefix.
        gallery: if True upload under ``oss_path_img_gallery``, else ``oss_path``.

    Returns:
        A signed GET URL for the uploaded object, valid for 24 hours.
    """
    prefix = oss_path_img_gallery if gallery else oss_path
    object_key = "/".join((prefix, name))
    payload = bytes(jsonStr, "utf-8")

    bucket.put_object(object_key, payload)
    # Signed URL: HTTP method, object key, expiry in seconds (24 h).
    return bucket.sign_url('GET', object_key, 60 * 60 * 24)


def _blend_toward_blue(np_rgb, np_weight):
    """Mix an RGB array with pure blue (0, 0, 255) and return it as uint8.

    ``np_weight`` broadcasts against ``np_rgb``. A weight of 1 keeps the
    pixel unchanged; a weight of 0.5 gives a 50/50 mix with blue.
    """
    blended = np_rgb * np_weight + (1 - np_weight) * np.array([0, 0, 255])
    return blended.round().clip(0, 255).astype(np.uint8)


def _sketch_dict_preview(image_dict):
    """Render an ``{"image", "mask"}`` dict as RGB with the unselected area tinted blue.

    A pixel is kept as-is when the image alpha is > 127 and the mask value is
    < 127. Every other pixel is mixed 50/50 with blue.
    """
    np_rgba = np.array(image_dict["image"].convert("RGBA"))
    np_rgb, np_alpha = np_rgba[..., :3], np_rgba[..., 3]
    np_mask = np.array(image_dict["mask"].convert("L"))
    np_keep = ((np_alpha > 127) * (np_mask < 127)).astype(float)[..., None]
    return _blend_toward_blue(np_rgb, np_keep * 0.5 + 0.5)


def upload_preprocess(pil_base_image_rgba, pil_layout_image_dict, pil_style_image_dict, pil_color_image_dict,
                      pil_fg_mask):
    """Build uint8 RGB preview arrays of the user inputs for upload.

    Each input that is not None gets a preview where the area outside its
    selection is tinted blue. Inputs that are None yield None.

    Args:
        pil_base_image_rgba: base/subject image; only its first three channels are used.
        pil_layout_image_dict, pil_style_image_dict, pil_color_image_dict:
            dicts with ``"image"`` and ``"mask"`` entries supporting ``.convert``
            (presumably PIL images from a sketch UI component — confirm).
        pil_fg_mask: foreground mask with values 0..255. It must be provided
            whenever ``pil_base_image_rgba`` is. It is assumed to be
            single-channel and the same HxW as the base image — TODO confirm.

    Returns:
        ``(base, layout, style, color)`` numpy uint8 arrays (or None each).
    """
    np_out_base_image = None
    if pil_base_image_rgba is not None:
        np_fg_image = np.array(pil_base_image_rgba)[..., :3]
        # Mask 0..255 -> weight 0.5..1.0: background is tinted but still visible.
        np_fg_weight = np.expand_dims(np.array(pil_fg_mask).astype(float), axis=-1) / 255.
        np_out_base_image = _blend_toward_blue(np_fg_image, np_fg_weight * 0.5 + 0.5)

    np_out_layout_image, np_out_style_image, np_out_color_image = (
        None if image_dict is None else _sketch_dict_preview(image_dict)
        for image_dict in (pil_layout_image_dict, pil_style_image_dict, pil_color_image_dict)
    )

    return np_out_base_image, np_out_layout_image, np_out_style_image, np_out_color_image


def pad_image(image, target_size, fill=(255, 255, 255)):
    """Letterbox ``image`` into an RGB canvas of ``target_size``, keeping aspect ratio.

    The image is scaled with bicubic resampling so that it fits inside the
    target while at least one side matches. It is then centred on a canvas
    filled with ``fill`` (white by default).

    Args:
        image: PIL image to pad.
        target_size: ``(width, height)`` of the output.
        fill: RGB colour of the padding borders.

    Returns:
        A new RGB PIL image of size ``target_size``.
    """
    src_w, src_h = image.size
    dst_w, dst_h = target_size
    scale = min(dst_w / src_w, dst_h / src_h)
    # Round to nearest (+0.5). Clamp to 1 px so a very thin image does not
    # produce a zero-sized resize, which Pillow rejects.
    new_w = max(1, int(src_w * scale + 0.5))
    new_h = max(1, int(src_h * scale + 0.5))
    resized = image.resize((new_w, new_h), Image.BICUBIC)

    canvas = Image.new('RGB', target_size, fill)
    # Centre the resized image; uncovered borders keep the fill colour.
    canvas.paste(resized, ((dst_w - new_w) // 2, (dst_h - new_h) // 2))
    return canvas


def add_text(image, text):
    """Draw ``text`` in black at the top-left corner over a translucent white box.

    The font size is one tenth of the image height. The white box covering
    the text's bounding box is blended 50/50 with the image before the text
    is drawn. Returns an RGB copy; the input image is not modified.
    """
    font = ImageFont.truetype("assets/ttf/AlibabaPuHuiTi-2-55-Regular.ttf", int(image.size[1] / 10))
    box_left, box_top, box_right, box_bottom = font.getbbox(text)

    # Paint an opaque white box on a copy, then blend to make it semi-transparent.
    boxed = image.copy()
    ImageDraw.Draw(boxed).rectangle((0, 0, box_right + box_left, box_bottom + box_top), fill=(255, 255, 255))
    blended = Image.blend(image, boxed, 0.5)

    # Black text anchored at the top-left corner (x=0, y=0).
    ImageDraw.Draw(blended).text((0, 0), text, font=font, fill=(0, 0, 0, 255))
    return blended.convert("RGB")


def compose_image(image_list, text_list, pil_size, nrow, ncol):
    """Tile images row-major into an ``nrow`` x ``ncol`` grid on a white canvas.

    ``None`` entries in ``image_list`` are dropped together with their
    captions, and the remaining images fill the grid from the top-left. If
    exactly one image remains, the grid collapses to 1x1. Each image is
    resized (bicubic) to ``pil_size``; a non-None caption is drawn on it via
    ``add_text``.

    Raises:
        ValueError: if ``image_list`` has more entries than grid cells.
    """
    cell_w, cell_h = pil_size

    if len(image_list) > nrow * ncol:
        raise ValueError("合成图片的参数和要求的数量不能匹配!")

    assert len(image_list) == len(text_list)
    pairs = [(img, txt) for img, txt in zip(image_list, text_list) if img is not None]
    if len(pairs) == 1:
        nrow = ncol = 1

    canvas = Image.new('RGB', (ncol * cell_w, nrow * cell_h), (255, 255, 255))
    for idx, (img, txt) in enumerate(pairs):
        row, col = divmod(idx, ncol)
        tile = img.resize((cell_w, cell_h), Image.BICUBIC)
        if txt is not None:
            tile = add_text(tile, txt)
        canvas.paste(tile, (col * cell_w, row * cell_h))
    return canvas


def split_text_lines(text, max_w, ttf):
    """Greedily wrap ``text`` character-by-character into lines narrower than ``max_w``.

    Args:
        text: string to wrap. Wrapping is per character, so words may be split.
        max_w: maximum line width in pixels. A line grows while the right edge
            from ``ttf.getbbox`` stays below this value.
        ttf: font object exposing ``getbbox(str) -> (left, top, right, bottom)``.

    Returns:
        ``(lines, text_h)``. ``lines`` is the list of line strings, and
        ``text_h`` sums, over the lines, the ``bottom`` of the last bounding
        box measured while building each line. An empty string returns
        ``([], 0)``. A character wider than ``max_w`` is emitted on its own
        line rather than stalling.
    """
    text_split_lines = []
    text_h = 0
    n_chars = len(text)
    line_start = 0
    while line_start < n_chars:
        remaining = n_chars - line_start
        line_count = 0
        _, _, right, bottom = ttf.getbbox(text[line_start:line_start + 1])
        # Grow the line while the candidate (line + next char) still fits.
        # Bound by the remaining length so the last line stops once the tail is consumed.
        while right < max_w and line_count < remaining:
            line_count += 1
            _, _, right, bottom = ttf.getbbox(text[line_start:line_start + line_count + 1])
        if line_count == 0:
            # Even one character exceeds max_w: take it anyway, otherwise
            # line_start would never advance (infinite loop).
            line_count = 1
        text_split_lines.append(text[line_start:line_start + line_count])
        text_h += bottom
        line_start += line_count
    return text_split_lines, text_h


def add_prompt(image, prompt, negative_prompt):
    """Append the prompt and negative prompt as a white caption strip below ``image``.

    Non-empty prompts are prefixed with "Prompt: " and "Negative prompt: ",
    wrapped to the image width with ``split_text_lines``, and drawn in black
    at one twentieth of the image height. If both prompts are empty, the
    image is returned unchanged.
    """
    if prompt == "" and negative_prompt == "":
        return image

    labelled_prompt = prompt if prompt == "" else "Prompt: " + prompt
    labelled_negative = negative_prompt if negative_prompt == "" else "Negative prompt: " + negative_prompt

    width, height = image.size
    font = ImageFont.truetype("assets/ttf/AlibabaPuHuiTi-2-55-Regular.ttf", int(height / 20))

    lines = []
    caption_h = 0
    for part in (labelled_prompt, labelled_negative):
        part_lines, part_h = split_text_lines(part, width, font)
        lines += part_lines
        caption_h += part_h

    caption = Image.new(image.mode, (width, caption_h), color=(255, 255, 255))
    ImageDraw.Draw(caption).text((0, 0), "\n".join(lines), font=font, fill=(0, 0, 0))

    # Stack: original image on top, caption strip underneath.
    result = Image.new(image.mode, (width, height + caption_h), color=(255, 255, 255))
    result.paste(image, (0, 0))
    result.paste(caption, (0, height))
    return result


def merge_images(np_fg_image, np_layout_image, np_style_image, np_color_image, np_res_image, prompt, negative_prompt):
    """Build a side-by-side summary image and return it as a numpy array.

    The left half is a 2x2 grid of the available inputs (Layout, Style /
    Color, Subject), each letterboxed to the result's size. The whole grid is
    then shrunk to one cell. The right half is the result image. The prompts
    are appended underneath as a caption. Missing (None) inputs are skipped.
    """
    pil_res_image = Image.fromarray(np_res_image)
    size = pil_res_image.size

    def _to_padded_pil(np_image):
        if np_image is None:
            return None
        return pad_image(Image.fromarray(np_image), size)

    pil_fg, pil_layout, pil_style, pil_color = (
        _to_padded_pil(arr) for arr in (np_fg_image, np_layout_image, np_style_image, np_color_image)
    )

    input_grid = compose_image(
        [pil_layout, pil_style, pil_color, pil_fg],
        ['Layout', 'Style', 'Color', 'Subject'],
        size,
        nrow=2,
        ncol=2,
    ).resize(size, Image.BICUBIC)

    side_by_side = compose_image([input_grid, pil_res_image], [None, None], size, nrow=1, ncol=2)
    return np.array(add_prompt(side_by_side, prompt, negative_prompt))