File size: 5,479 Bytes
9dce458 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 |
import torch
import numpy as np
import cv2
import os
import einops
import safetensors
import safetensors.torch
from PIL import Image
from omegaconf import OmegaConf
from .common import OfflineInpainter
from ..utils import resize_keep_aspect
from .booru_tagger import Tagger
from .sd_hack import hack_everything
from .ldm.util import instantiate_from_config
def get_state_dict(d):
    """Unwrap a checkpoint container: return d['state_dict'] if present, else d itself."""
    if 'state_dict' in d:
        return d['state_dict']
    return d
def load_state_dict(ckpt_path, location='cpu'):
    """Load a model state dict from a checkpoint file.

    Supports both ``.safetensors`` files (loaded via safetensors) and
    regular torch checkpoints (loaded via ``torch.load``).

    Args:
        ckpt_path: Path to the checkpoint file.
        location: Device to map the tensors to (default ``'cpu'``).

    Returns:
        The (unwrapped) state dict.
    """
    _, extension = os.path.splitext(ckpt_path)
    if extension.lower() == ".safetensors":
        # safetensors.torch is already imported at module level.
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        state_dict = torch.load(ckpt_path, map_location=torch.device(location))
    # Unwrap a possible {'state_dict': ...} container; no-op if already flat.
    # (The original applied get_state_dict twice on the torch branch.)
    return get_state_dict(state_dict)
def create_model(config_path):
    """Instantiate the model described by the given OmegaConf YAML config.

    The model is created on (or moved to) the CPU; the caller is
    responsible for moving it to its target device.
    """
    cfg = OmegaConf.load(config_path)
    return instantiate_from_config(cfg.model).cpu()
def load_ldm_sd(model, path):
    """Load weights from *path* into *model* (non-strict).

    BUGFIX: the suffix check used to be ``'.safetensor'`` (missing trailing
    's'), so safetensors files always fell through to ``load_state_dict``.
    That happened to work because load_state_dict also handles
    ``.safetensors``, but the fast path was dead code. The suffix is fixed;
    either branch yields the same result for safetensors files.
    """
    if path.endswith('.safetensors'):
        sd = safetensors.torch.load_file(path)
    else:
        sd = load_state_dict(path)
    # strict=False: the inpainting checkpoint may omit / add keys relative
    # to the instantiated architecture.
    model.load_state_dict(sd, strict = False)
class StableDiffusionInpainter(OfflineInpainter):
    """Inpainter backed by a Stable Diffusion inpainting model.

    A WD SwinV2 booru tagger labels the image content and the resulting tags
    are used as the positive prompt, so the diffusion model fills masked
    regions with content consistent with the rest of the image.
    """

    _MODEL_MAPPING = {
        'model_grapefruit': {
            'url': 'https://civitai.com/api/download/models/8364',
            'hash': 'dd680bd77d553e095faf58ff8c12584efe2a9b844e18bcc6ba2a366b85caceb8',
            'file': 'abyssorangemix2_Hard-inpainting.safetensors',
        },
        'model_wd_swinv2': {
            'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.onnx',
            'hash': '67740df7ede9a53e50d6e29c6a5c0d6c862f1876c22545d810515bad3ae17bb1',
            'file': 'wd_swinv2.onnx',
        },
        'model_wd_swinv2_csv': {
            'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/raw/main/selected_tags.csv',
            'hash': '8c8750600db36233a1b274ac88bd46289e588b338218c2e4c62bbc9f2b516368',
            'file': 'selected_tags.csv',
        }
    }

    def __init__(self, *args, **kwargs):
        # NOTE(review): self.model_dir is read before super().__init__() runs,
        # so it is presumably a class attribute on OfflineInpainter — confirm.
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(*args, **kwargs)

    async def _load(self, device: str):
        """Build the tagger and diffusion model and move them to *device*."""
        self.tagger = Tagger(self._get_file_path('wd_swinv2.onnx'))
        # BUGFIX: the model used to be unconditionally .cuda()'d here, which
        # crashed on CPU-only hosts even when device == 'cpu'. Build it on
        # CPU and rely on the single .to(device) below.
        self.model = create_model('manga_translator/inpainting/guided_ldm_inpaint9_v15.yaml')
        load_ldm_sd(self.model, self._get_file_path('abyssorangemix2_Hard-inpainting.safetensors'))
        hack_everything()
        self.model.eval()
        self.device = device
        self.model = self.model.to(device)

    async def _unload(self):
        del self.model

    @torch.no_grad()
    async def _infer(self, image: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, verbose: bool = False) -> np.ndarray:
        """Inpaint the masked region of *image*.

        Args:
            image: RGB image, shape (H, W, 3), uint8.
            mask: Single-channel mask, same H/W; values >= 127 mark the
                region to inpaint.
            inpainting_size: Maximum side length used for diffusion; larger
                inputs are downscaled for inference and the result is
                upscaled back.
            verbose: Unused here; kept for interface compatibility.

        Returns:
            The inpainted image at the original resolution, uint8.
        """
        img_original = np.copy(image)
        mask_original = np.copy(mask)
        # Binarize the full-resolution mask (0/1) for the final composite.
        mask_original[mask_original < 127] = 0
        mask_original[mask_original >= 127] = 1
        mask_original = mask_original[:, :, None]

        height, width, c = image.shape
        if max(image.shape[0: 2]) > inpainting_size:
            image = resize_keep_aspect(image, inpainting_size)
            mask = resize_keep_aspect(mask, inpainting_size)
        # Round dimensions up to a multiple of 64 so the model's
        # downsampling stack divides evenly.
        pad_size = 64
        h, w, c = image.shape
        if h % pad_size != 0:
            new_h = (pad_size - (h % pad_size)) + h
        else:
            new_h = h
        if w % pad_size != 0:
            new_w = (pad_size - (w % pad_size)) + w
        else:
            new_w = w
        if new_h != h or new_w != w:
            image = cv2.resize(image, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
            mask = cv2.resize(mask, (new_w, new_h), interpolation = cv2.INTER_LINEAR)
        self.logger.info(f'Inpainting resolution: {new_w}x{new_h}')

        # Tag the image content to build a context-aware positive prompt.
        # The tagger expects BGR input; our pipeline image is RGB.
        tags = self.tagger.label_cv2_bgr(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        self.logger.info(f'tags={list(tags.keys())}')
        blacklist = set()  # placeholder: tags to exclude from the prompt
        pos_prompt = ','.join([x for x in tags.keys() if x not in blacklist]).replace('_', ' ')
        pos_prompt = 'masterpiece,best quality,' + pos_prompt
        neg_prompt = 'worst quality, low quality, normal quality,text,text,text,text'

        def _run_inpaint():
            # Single call site shared by the autocast and plain paths
            # (the original duplicated this entire call in both branches).
            return self.model.img2img_inpaint(
                image = Image.fromarray(image),
                c_text = pos_prompt,
                uc_text = neg_prompt,
                mask = Image.fromarray(mask),
                device = self.device
            )

        if self.device.startswith('cuda') :
            # Mixed precision on CUDA for speed/VRAM.
            with torch.autocast(enabled = True, device_type = 'cuda') :
                img = _run_inpaint()
        else :
            img = _run_inpaint()

        # Model output is (1, C, H, W) in [-1, 1]; map back to uint8 HWC.
        img_inpainted = (einops.rearrange(img, '1 c h w -> h w c').cpu().numpy() * 127.5 + 127.5).astype(np.uint8)
        if new_h != height or new_w != width:
            img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation = cv2.INTER_LINEAR)
        # Composite: take inpainted pixels only where the mask was set,
        # original pixels everywhere else.
        ans = img_inpainted * mask_original + img_original * (1 - mask_original)
        return ans
|