# utils.py
from enum import Enum

import torch
from huggingface_hub import hf_hub_download
from PIL import Image, ImageEnhance, ImageFilter
import cv2
import numpy as np
from refiners.fluxion.utils import load_from_safetensors, tensor_to_image
from refiners.foundationals.clip import CLIPTextEncoderL
from refiners.foundationals.latent_diffusion import SD1UNet
from refiners.foundationals.latent_diffusion.stable_diffusion_1 import SD1Autoencoder
from refiners.foundationals.latent_diffusion.stable_diffusion_1.ic_light import ICLight


def load_ic_light(device: torch.device, dtype: torch.dtype) -> ICLight:
    """Load the IC-Light FC adapter over Realistic Vision v5.1 (SD 1.5) weights from the Hugging Face Hub."""
    return ICLight(
        patch_weights=load_from_safetensors(
            path=hf_hub_download(
                repo_id="refiners/sd15.ic_light.fc",
                filename="model.safetensors",
                revision="ea10b4403e97c786a98afdcbdf0e0fec794ea542",
            ),
        ),
        unet=SD1UNet(in_channels=4, device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.unet",
                filename="model.safetensors",
                revision="94f74be7adfd27bee330ea1071481c0254c29989",
            )
        ),
        clip_text_encoder=CLIPTextEncoderL(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.text_encoder",
                filename="model.safetensors",
                revision="7f6fa1e870c8f197d34488e14b89e63fb8d7fd6e",
            )
        ),
        lda=SD1Autoencoder(device=device, dtype=dtype).load_from_safetensors(
            tensors_path=hf_hub_download(
                repo_id="refiners/sd15.realistic_vision.v5_1.autoencoder",
                filename="model.safetensors",
                revision="99f089787a6e1a852a0992da1e286a19fcbbaa50",
            )
        ),
        device=device,
        dtype=dtype,
    )
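

# Usage sketch (illustrative, assuming a CUDA device is available; the weights are
# fetched from the Hugging Face Hub on first call):
#
#     ic_light = load_ic_light(device=torch.device("cuda"), dtype=torch.float16)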


def resize_modulo_8(
    image: Image.Image,
    size: int = 768,
    resample: Image.Resampling | None = None,
    on_short: bool = True,
) -> Image.Image:
    """이미지 크기를 8의 배수로 조정"""
    assert size % 8 == 0, "Size must be a multiple of 8 because this is the latent compression size."
    side_size = min(image.size) if on_short else max(image.size)
    scale = size / side_size
    # Floor each dimension to the nearest multiple of 8 (the latent compression factor).
    new_size = (int(image.width * scale / 8) * 8, int(image.height * scale / 8) * 8)
    return image.resize(new_size, resample=resample or Image.Resampling.LANCZOS)
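

# Worked example: a 1000x600 image with size=768 and on_short=True scales by
# 768 / 600 = 1.28, and each side is floored to a multiple of 8:
# (int(1000 * 1.28 / 8) * 8, int(600 * 1.28 / 8) * 8) == (1280, 768).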


def adjust_image(
    image: Image.Image,
    brightness: float = 0.0,
    contrast: float = 0.0,
    temperature: float = 0.0,
    saturation: float = 0.0,
    tint: float = 0.0,
    blur_intensity: float = 0.0,
    exposure: float = 0.0,
    vibrance: float = 0.0,
    color_mixer_blues: float = 0.0,
) -> Image.Image:
    """Apply a chain of photographic adjustments (exposure, brightness, contrast, color, blur) to an image."""
    image = image.convert("RGB")

    # Exposure adjustment
    if exposure != 0.0:
        # Exposure ranges from -5 to 5, where 0 is neutral
        exposure_factor = 1 + (exposure / 5.0)
        exposure_factor = max(exposure_factor, 0.01)  # Prevent zero or negative
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(exposure_factor)

    # Brightness adjustment
    if brightness != 0.0:
        # Brightness ranges from -5 to 5, mapped to brightness factor
        brightness_factor = 1 + (brightness / 5.0)
        brightness_factor = max(brightness_factor, 0.01)  # Prevent zero or negative
        enhancer = ImageEnhance.Brightness(image)
        image = enhancer.enhance(brightness_factor)

    # Contrast adjustment
    if contrast != 0.0:
        # Contrast ranges from -100 to 100, mapped to contrast factor
        contrast_factor = 1 + (contrast / 100.0)
        contrast_factor = max(contrast_factor, 0.01)  # Prevent zero or negative
        enhancer = ImageEnhance.Contrast(image)
        image = enhancer.enhance(contrast_factor)

    # Vibrance adjustment
    if vibrance != 0.0:
        # Approximated with a plain saturation scale (true vibrance would boost low-saturation pixels more)
        vibrance_factor = 1 + (vibrance / 100.0)
        vibrance_factor = max(vibrance_factor, 0.0)  # Prevent negative saturation
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(vibrance_factor)

    # Saturation adjustment
    if saturation != 0.0:
        # Saturation ranges from -100 to 100, mapped to saturation factor
        saturation_factor = 1 + (saturation / 100.0)
        saturation_factor = max(saturation_factor, 0.0)  # Prevent negative saturation
        enhancer = ImageEnhance.Color(image)
        image = enhancer.enhance(saturation_factor)

    # Color temperature adjustment: warm up (scale red) or cool down (scale blue)
    if temperature != 0.0:
        # Clamp the factor to stay positive, since the blue channel divides by it below
        temp_factor = 1 + (temperature / 100.0)
        temp_factor = max(temp_factor, 0.01)  # Prevent zero or negative

        r, g, b = image.split()
        r = r.point(lambda i: i * temp_factor)
        b = b.point(lambda i: i / temp_factor)
        image = Image.merge('RGB', (r, g, b))

    # Tint (hue shift) adjustment
    if tint != 0.0:
        image_np = np.array(image)
        image_hsv = cv2.cvtColor(image_np, cv2.COLOR_RGB2HSV).astype(np.float32)
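        # OpenCV stores hue in [0, 180), so wrap the shifted hue, then clamp before casting back to uint8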
        image_hsv[:, :, 0] = (image_hsv[:, :, 0] + tint) % 180
        image_hsv[:, :, 0] = np.clip(image_hsv[:, :, 0], 0, 179)
        image_rgb = cv2.cvtColor(image_hsv.astype(np.uint8), cv2.COLOR_HSV2RGB)
        image = Image.fromarray(image_rgb)

    # Apply Gaussian blur
    if blur_intensity > 0:
        image = image.filter(ImageFilter.GaussianBlur(radius=blur_intensity))

    # Color Mixer (Blues)
    if color_mixer_blues != 0.0:
        image_np = np.array(image).astype(np.float32)
        # Adjust the blue channel
        image_np[:, :, 2] = np.clip(image_np[:, :, 2] + (color_mixer_blues / 100.0) * 255, 0, 255)
        image = Image.fromarray(image_np.astype(np.uint8))

    return image
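

# Usage sketch with illustrative values ("photo.jpg" is a placeholder path):
#
#     warmer = adjust_image(Image.open("photo.jpg"), exposure=0.5, temperature=15.0, blur_intensity=2)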


class LightingPreference(str, Enum):
    # Explicit string values: auto() does not yield the member names when str is mixed in.
    LEFT = "left"
    RIGHT = "right"
    TOP = "top"
    BOTTOM = "bottom"
    NONE = "none"

    def get_init_image(self, width: int, height: int, interval: tuple[float, float] = (0.0, 1.0)) -> Image.Image | None:
        """조명 선호도에 따른 그라데이션 이미지 생성"""
        start, end = interval
        match self:
            case LightingPreference.LEFT:
                tensor = torch.linspace(end, start, width).repeat(1, 1, height, 1)
            case LightingPreference.RIGHT:
                tensor = torch.linspace(start, end, width).repeat(1, 1, height, 1)
            case LightingPreference.TOP:
                tensor = torch.linspace(end, start, height).repeat(1, 1, width, 1).transpose(2, 3)
            case LightingPreference.BOTTOM:
                tensor = torch.linspace(start, end, height).repeat(1, 1, width, 1).transpose(2, 3)
            case LightingPreference.NONE:
                return None

        return tensor_to_image(tensor).convert("RGB")
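
    # Example: LightingPreference.LEFT.get_init_image(512, 384) yields a 512x384
    # RGB gradient fading from white on the left to black on the right.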

    @classmethod
    def from_str(cls, value: str) -> "LightingPreference":
        match value.lower():
            case "left":
                return LightingPreference.LEFT
            case "right":
                return LightingPreference.RIGHT
            case "top":
                return LightingPreference.TOP
            case "bottom":
                return LightingPreference.BOTTOM
            case "none":
                return LightingPreference.NONE
            case _:
                raise ValueError(f"Invalid lighting preference: {value}")
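

if __name__ == "__main__":
    # Smoke test (a sketch, not part of the library): build and save one gradient.
    preference = LightingPreference.from_str("left")
    init_image = preference.get_init_image(width=512, height=384)
    assert init_image is not None
    init_image.save("lighting_left.png")  # "lighting_left.png" is an illustrative output path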