import json
from pathlib import Path
from typing import Dict, Optional

import cv2
import psutil
from PIL import Image
from loguru import logger
from rich.console import Console
from rich.progress import (
    Progress,
    SpinnerColumn,
    TimeElapsedColumn,
    MofNCompleteColumn,
    TextColumn,
    BarColumn,
    TaskProgressColumn,
)

from iopaint.helper import pil_to_bytes
from iopaint.model.utils import torch_gc
from iopaint.model_manager import ModelManager
from iopaint.schema import InpaintRequest


def glob_images(path: Path) -> Dict[str, Path]:
    """Collect image files under ``path``, keyed by file stem.

    Args:
        path: A single file or a directory. A file is returned as-is, without
            checking its extension. A directory is scanned non-recursively
            for ``.png``, ``.jpg`` and ``.jpeg`` files (case-insensitive).

    Returns:
        A mapping from file stem to file path. When several files in a
        directory share a stem, only the last one yielded by the directory
        scan is kept. The mapping is empty when ``path`` is neither an
        existing file nor an existing directory, or when the directory holds
        no matching files.
    """
    if path.is_file():
        return {path.stem: path}
    if path.is_dir():
        return {
            it.stem: it
            for it in path.glob("*.*")
            if it.suffix.lower() in (".png", ".jpg", ".jpeg")
        }
    # Neither a file nor a directory (e.g. the path does not exist): return an
    # empty mapping rather than an implicit None so the result is always a dict.
    return {}


def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint every image under ``image`` and save the results as PNG files.

    Args:
        model: Model name passed to ``ModelManager``.
        device: Device passed to ``ModelManager``.
        image: An image file or a directory of images (png/jpg/jpeg).
        mask: A mask file or a directory of masks. When it is a directory,
            each image uses the mask with the same file stem, and images
            without a matching mask are skipped. When it is a single file,
            that mask is used for every image.
        output: Directory where ``<stem>.png`` results are written; it is
            created if missing.
        config: Optional JSON file whose top-level keys are passed as keyword
            arguments to ``InpaintRequest``. A default request is used when
            omitted.
        concat: When true, the saved picture is the input image, the mask and
            the result placed side by side instead of the result alone.

    Raises:
        SystemExit: With code -1 when ``output`` is a file while ``image`` is
            a directory, or when no images or no masks are found.
    """
    if image.is_dir() and output.is_file():
        logger.error(
            "invalid --output: when image is a directory, output should be a directory"
        )
        # SystemExit is what the interactive ``exit`` helper raises, but it
        # does not depend on the ``site`` module being loaded.
        raise SystemExit(-1)
    output.mkdir(parents=True, exist_ok=True)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    if not image_paths:
        logger.error("invalid --image: empty image folder")
        raise SystemExit(-1)
    if not mask_paths:
        logger.error("invalid --mask: empty mask folder")
        raise SystemExit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    model_manager = ModelManager(name=model, device=device)
    # Fallback mask, used when ``mask`` is a single file shared by all images.
    first_mask = next(iter(mask_paths.values()))
    # Loop-invariant filesystem check, evaluated once.
    mask_is_dir = mask.is_dir()

    console = Console()

    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if mask_is_dir and stem not in mask_paths:
                progress.log(f"mask for {image_p} not found")
                progress.update(task, advance=1)
                continue
            mask_p = mask_paths.get(stem, first_mask)

            # cv2.imread signals failure by returning None instead of raising,
            # so check explicitly and skip the file rather than abort the batch.
            img = cv2.imread(str(image_p))
            if img is None:
                progress.log(f"failed to read image {image_p}, skipped")
                progress.update(task, advance=1)
                continue
            mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
            if mask_img is None:
                progress.log(f"failed to read mask {mask_p}, skipped {image_p}")
                progress.update(task, advance=1)
                continue

            # Only the metadata dict is needed from PIL; close the file
            # handle right away instead of leaving it to garbage collection.
            with Image.open(image_p) as pil_image:
                infos = pil_image.info

            # imread with default flags yields a 3-channel BGR array.
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
            if mask_img.shape[:2] != img.shape[:2]:
                progress.log(
                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
                )
                mask_img = cv2.resize(
                    mask_img,
                    (img.shape[1], img.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Binarize the mask: values >= 127 become 255, the rest 0.
            mask_img[mask_img >= 127] = 255
            mask_img[mask_img < 127] = 0

            # The model output is treated as BGR and converted to RGB for PIL.
            inpaint_result = model_manager(img, mask_img, inpaint_request)
            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
            if concat:
                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])

            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
            save_p = output / f"{stem}.png"
            with open(save_p, "wb") as fw:
                fw.write(img_bytes)

            progress.update(task, advance=1)
            torch_gc()