import torch import numpy as np import cv2 import os import einops import safetensors import safetensors.torch from PIL import Image from omegaconf import OmegaConf from .common import OfflineInpainter from ..utils import resize_keep_aspect from .booru_tagger import Tagger from .sd_hack import hack_everything from .ldm.util import instantiate_from_config def get_state_dict(d): return d.get('state_dict', d) def load_state_dict(ckpt_path, location='cpu'): _, extension = os.path.splitext(ckpt_path) if extension.lower() == ".safetensors": import safetensors.torch state_dict = safetensors.torch.load_file(ckpt_path, device=location) else: state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location))) state_dict = get_state_dict(state_dict) return state_dict def create_model(config_path): config = OmegaConf.load(config_path) model = instantiate_from_config(config.model).cpu() return model def load_ldm_sd(model, path) : if path.endswith('.safetensor') : sd = safetensors.torch.load_file(path) else : sd = load_state_dict(path) model.load_state_dict(sd, strict = False) class StableDiffusionInpainter(OfflineInpainter): _MODEL_MAPPING = { 'model_grapefruit': { 'url': 'https://civitai.com/api/download/models/8364', 'hash': 'dd680bd77d553e095faf58ff8c12584efe2a9b844e18bcc6ba2a366b85caceb8', 'file': 'abyssorangemix2_Hard-inpainting.safetensors', }, 'model_wd_swinv2': { 'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.onnx', 'hash': '67740df7ede9a53e50d6e29c6a5c0d6c862f1876c22545d810515bad3ae17bb1', 'file': 'wd_swinv2.onnx', }, 'model_wd_swinv2_csv': { 'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/raw/main/selected_tags.csv', 'hash': '8c8750600db36233a1b274ac88bd46289e588b338218c2e4c62bbc9f2b516368', 'file': 'selected_tags.csv', } } def __init__(self, *args, **kwargs): os.makedirs(self.model_dir, exist_ok=True) super().__init__(*args, **kwargs) async def _load(self, device: str): self.tagger = Tagger(self._get_file_path('wd_swinv2.onnx')) self.model = create_model('manga_translator/inpainting/guided_ldm_inpaint9_v15.yaml').cuda() load_ldm_sd(self.model, self._get_file_path('abyssorangemix2_Hard-inpainting.safetensors')) hack_everything() self.model.eval() self.device = device self.model = self.model.to(device) async def _unload(self): del self.model @torch.no_grad() async def _infer(self, image: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, verbose: bool = False) -> np.ndarray: img_original = np.copy(image) mask_original = np.copy(mask) mask_original[mask_original < 127] = 0 mask_original[mask_original >= 127] = 1 mask_original = mask_original[:, :, None] height, width, c = image.shape if max(image.shape[0: 2]) > inpainting_size: image = resize_keep_aspect(image, inpainting_size) mask = resize_keep_aspect(mask, inpainting_size) pad_size = 64 h, w, c = image.shape if h % pad_size != 0: new_h = (pad_size - (h % pad_size)) + h else: new_h = h if w % pad_size != 0: new_w = (pad_size - (w % pad_size)) + w else: new_w = w if new_h != h or new_w != w: image = cv2.resize(image, (new_w, new_h), interpolation = cv2.INTER_LINEAR) mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR) self.logger.info(f'Inpainting resolution: {new_w}x{new_h}') tags = self.tagger.label_cv2_bgr(cv2.cvtColor(image, cv2.COLOR_RGB2BGR)) self.logger.info(f'tags={list(tags.keys())}') blacklist = set() pos_prompt = ','.join([x for x in tags.keys() if x not in blacklist]).replace('_', ' ') pos_prompt = 'masterpiece,best quality,' + pos_prompt neg_prompt = 'worst quality, low quality, normal quality,text,text,text,text' if self.device.startswith('cuda') : with torch.autocast(enabled = True, device_type = 'cuda') : img = self.model.img2img_inpaint( image = Image.fromarray(image), c_text = pos_prompt, uc_text = neg_prompt, mask = Image.fromarray(mask), device = self.device ) else : img = self.model.img2img_inpaint( image = Image.fromarray(image), c_text = pos_prompt, uc_text = neg_prompt, mask = Image.fromarray(mask), device = self.device ) img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8) if new_h != height or new_w != width: img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR) ans = img_inpainted * mask_original + img_original * (1 - mask_original) return ans