Spaces:
Runtime error
Runtime error
| import json | |
| from pathlib import Path | |
| from typing import Dict, Optional | |
| import cv2 | |
| import psutil | |
| from PIL import Image | |
| from loguru import logger | |
| from rich.console import Console | |
| from rich.progress import ( | |
| Progress, | |
| SpinnerColumn, | |
| TimeElapsedColumn, | |
| MofNCompleteColumn, | |
| TextColumn, | |
| BarColumn, | |
| TaskProgressColumn, | |
| ) | |
| from .helper import pil_to_bytes | |
| from .model.utils import torch_gc | |
| from .model_manager import ModelManager | |
| from .schema import InpaintRequest | |
def glob_images(path: Path) -> Dict[str, Path]:
    """Collect the images reachable from ``path``, keyed by file stem.

    Args:
        path: A single file, or a directory that is scanned non-recursively.

    Returns:
        A mapping of file stem to path.

        * A file is returned as the only entry, whatever its suffix.
        * A directory yields its direct children whose suffix is ``.png``,
          ``.jpg`` or ``.jpeg`` (case-insensitive).
        * Anything else (e.g. a path that does not exist) yields an empty
          mapping rather than ``None``, so callers can always treat the
          result as a dict.
    """
    image_suffixes = (".png", ".jpg", ".jpeg")

    if path.is_file():
        return {path.stem: path}
    if not path.is_dir():
        return {}

    # Sort so that iteration order, and the winner when two files share a
    # stem (e.g. ``a.jpg`` and ``a.png``), does not depend on the filesystem.
    return {
        it.stem: it
        for it in sorted(path.glob("*.*"))
        if it.suffix.lower() in image_suffixes
    }
def batch_inpaint(
    model: str,
    device,
    image: Path,
    mask: Path,
    output: Path,
    config: Optional[Path] = None,
    concat: bool = False,
):
    """Inpaint one image, or every image in a folder, and save PNG results.

    Args:
        model: Model name forwarded to ``ModelManager``.
        device: Device forwarded to ``ModelManager``.
        image: An image file or a directory of images.
        mask: A mask file or a directory of masks. Masks are matched to
            images by file stem. When ``mask`` is a directory, images
            without a matching mask are skipped; otherwise the first
            collected mask is used as the fallback.
        output: Directory (created if missing) receiving ``<stem>.png``.
        config: Optional JSON file whose top-level keys are passed as
            keyword arguments to ``InpaintRequest``. Defaults are used
            when this is ``None``.
        concat: When true, the saved picture is the input image, the mask
            and the result concatenated horizontally.

    Raises:
        SystemExit: With code ``-1`` when ``image`` is a directory but
            ``output`` is a file, or when no image or no mask is found.
    """
    if image.is_dir() and output.is_file():
        logger.error(
            "invalid --output: when image is a directory, output should be a directory"
        )
        raise SystemExit(-1)
    output.mkdir(parents=True, exist_ok=True)

    image_paths = glob_images(image)
    mask_paths = glob_images(mask)
    if not image_paths:
        logger.error("invalid --image: empty image folder")
        raise SystemExit(-1)
    if not mask_paths:
        logger.error("invalid --mask: empty mask folder")
        raise SystemExit(-1)

    if config is None:
        inpaint_request = InpaintRequest()
        # logger.info(f"Using default config: {inpaint_request}")
    else:
        with open(config, "r", encoding="utf-8") as f:
            inpaint_request = InpaintRequest(**json.load(f))

    model_manager = ModelManager(name=model, device=device)
    # Fallback mask, used when a single mask file is shared by all images.
    first_mask = next(iter(mask_paths.values()))
    # Loop-invariant: decides whether a missing per-image mask is an error.
    mask_is_dir = mask.is_dir()

    console = Console()
    with Progress(
        SpinnerColumn(),
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        TaskProgressColumn(),
        MofNCompleteColumn(),
        TimeElapsedColumn(),
        console=console,
        transient=False,
    ) as progress:
        task = progress.add_task("Batch processing...", total=len(image_paths))
        for stem, image_p in image_paths.items():
            if stem not in mask_paths and mask_is_dir:
                progress.log(f"mask for {image_p} not found")
                progress.update(task, advance=1)
                continue
            mask_p = mask_paths.get(stem, first_mask)

            # cv2.imread signals an unreadable file by returning None
            # instead of raising; skip such items rather than aborting
            # the whole batch.
            img = cv2.imread(str(image_p))
            mask_img = cv2.imread(str(mask_p), cv2.IMREAD_GRAYSCALE)
            if img is None or mask_img is None:
                unreadable = image_p if img is None else mask_p
                progress.log(f"failed to read {unreadable}, skip {image_p.name}")
                progress.update(task, advance=1)
                continue

            # Only the metadata dict is needed from PIL; copy it and close
            # the file right away so handles do not pile up over the batch.
            with Image.open(image_p) as pil_img:
                infos = dict(pil_img.info)

            img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGB)
            if mask_img.shape[:2] != img.shape[:2]:
                progress.log(
                    f"resize mask {mask_p.name} to image {image_p.name} size: {img.shape[:2]}"
                )
                mask_img = cv2.resize(
                    mask_img,
                    (img.shape[1], img.shape[0]),
                    interpolation=cv2.INTER_NEAREST,
                )
            # Binarize the mask: values >= 127 become 255, the rest 0.
            mask_img[mask_img >= 127] = 255
            mask_img[mask_img < 127] = 0

            # The model output is presumably BGR (per the original "bgr"
            # note), hence the conversion to RGB before saving.
            inpaint_result = model_manager(img, mask_img, inpaint_request)
            inpaint_result = cv2.cvtColor(inpaint_result, cv2.COLOR_BGR2RGB)
            if concat:
                mask_img = cv2.cvtColor(mask_img, cv2.COLOR_GRAY2RGB)
                inpaint_result = cv2.hconcat([img, mask_img, inpaint_result])

            img_bytes = pil_to_bytes(Image.fromarray(inpaint_result), "png", 100, infos)
            save_p = output / f"{stem}.png"
            with open(save_p, "wb") as fw:
                fw.write(img_bytes)

            progress.update(task, advance=1)
            torch_gc()
            # Debug aid: report the image size and this process's memory use.
            # pid = psutil.Process().pid
            # memory_info = psutil.Process(pid).memory_info()
            # memory_in_mb = memory_info.rss / (1024 * 1024)
            # print(f"image size: {img.shape}, process memory usage: {memory_in_mb}MB")