File size: 5,479 Bytes
9dce458
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import torch
import numpy as np
import cv2
import os
import einops
import safetensors
import safetensors.torch
from PIL import Image
from omegaconf import OmegaConf

from .common import OfflineInpainter
from ..utils import resize_keep_aspect

from .booru_tagger import Tagger
from .sd_hack import hack_everything
from .ldm.util import instantiate_from_config


def get_state_dict(d):
    """Return the weights mapping held by a loaded checkpoint.

    If the checkpoint wraps its weights under a ``'state_dict'`` key, that inner
    value is returned (one level only); otherwise ``d`` itself is returned.
    """
    if 'state_dict' in d:
        return d['state_dict']
    return d

def load_state_dict(ckpt_path, location='cpu'):
    """Load a state dict from a safetensors file or a ``torch.load``-able checkpoint.

    Args:
        ckpt_path: Path to the checkpoint. Files ending in ``.safetensors`` or
            ``.safetensor`` (case-insensitive) are read with safetensors; any other
            extension goes through ``torch.load``.
        location: Device the tensors are placed on (default ``'cpu'``).

    Returns:
        The weights mapping. A wrapping ``'state_dict'`` entry is unwrapped once for
        safetensors files and up to twice for ``torch.load`` checkpoints.
    """
    extension = os.path.splitext(ckpt_path)[1].lower()
    if extension in ('.safetensors', '.safetensor'):
        # ``safetensors.torch`` is already imported at module level.
        state_dict = safetensors.torch.load_file(ckpt_path, device=location)
    else:
        # NOTE(review): torch.load unpickles the file, which can execute arbitrary
        # code; only pass checkpoints from trusted sources.
        state_dict = get_state_dict(torch.load(ckpt_path, map_location=torch.device(location)))
    return get_state_dict(state_dict)


def create_model(config_path):
    """Instantiate the model described by the ``model`` section of a YAML config.

    The config is read with OmegaConf, the model is built through
    ``instantiate_from_config``, and the result is returned on the CPU.
    """
    model_config = OmegaConf.load(config_path).model
    return instantiate_from_config(model_config).cpu()


def load_ldm_sd(model, path):
    """Load checkpoint weights from ``path`` into ``model`` non-strictly.

    Files ending in ``.safetensors`` or ``.safetensor`` (case-insensitive) are read
    directly with safetensors. Anything else is delegated to ``load_state_dict``.
    The original check only matched the singular ``.safetensor`` suffix, so real
    ``.safetensors`` files silently took the fallback path.

    Args:
        model: Object exposing ``load_state_dict(sd, strict=...)``.
        path: Checkpoint file path.

    Returns:
        Whatever ``model.load_state_dict`` returns. Because loading is non-strict,
        callers can inspect this for missing/unexpected keys.
    """
    if path.lower().endswith(('.safetensors', '.safetensor')):
        sd = safetensors.torch.load_file(path)
    else:
        sd = load_state_dict(path)
    # strict=False: keys absent on either side are tolerated rather than raising.
    return model.load_state_dict(sd, strict=False)

class StableDiffusionInpainter(OfflineInpainter):
    """Offline inpainter backed by a guided latent-diffusion inpainting model.

    Before inpainting, the image is labelled with a WD v1.4 SwinV2 tagger (ONNX).
    The resulting tags form the positive text prompt passed to
    ``self.model.img2img_inpaint``.
    """

    _MODEL_MAPPING = {
        'model_grapefruit': {
            'url': 'https://civitai.com/api/download/models/8364',
            'hash': 'dd680bd77d553e095faf58ff8c12584efe2a9b844e18bcc6ba2a366b85caceb8',
            'file': 'abyssorangemix2_Hard-inpainting.safetensors',
        },
        'model_wd_swinv2': {
            'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/resolve/main/model.onnx',
            'hash': '67740df7ede9a53e50d6e29c6a5c0d6c862f1876c22545d810515bad3ae17bb1',
            'file': 'wd_swinv2.onnx',
        },
        'model_wd_swinv2_csv': {
            'url': 'https://huggingface.co/SmilingWolf/wd-v1-4-swinv2-tagger-v2/raw/main/selected_tags.csv',
            'hash': '8c8750600db36233a1b274ac88bd46289e588b338218c2e4c62bbc9f2b516368',
            'file': 'selected_tags.csv',
        }
    }

    # Model config file name; looked up next to this module first.
    _CONFIG_NAME = 'guided_ldm_inpaint9_v15.yaml'
    # Historical CWD-relative location, kept as a fallback.
    _LEGACY_CONFIG_PATH = 'manga_translator/inpainting/guided_ldm_inpaint9_v15.yaml'
    # Inference height/width are stretched up to a multiple of this many pixels.
    # NOTE(review): presumably required by the model's down/up-sampling — confirm.
    _SIZE_MULTIPLE = 64

    def __init__(self, *args, **kwargs):
        # ``model_dir`` is provided by the base class (not visible here). It is
        # created before ``super().__init__`` runs.
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(*args, **kwargs)

    def _resolve_config_path(self) -> str:
        """Return the model config path, independent of the process CWD when possible."""
        local_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), self._CONFIG_NAME)
        return local_path if os.path.isfile(local_path) else self._LEGACY_CONFIG_PATH

    async def _load(self, device: str):
        """Load the tagger and the diffusion model, then move the model to ``device``."""
        self.tagger = Tagger(self._get_file_path('wd_swinv2.onnx'))
        # create_model returns a CPU model. It is moved to ``device`` only once, at
        # the end; the former unconditional ``.cuda()`` crashed on CPU-only hosts.
        self.model = create_model(self._resolve_config_path())
        load_ldm_sd(self.model, self._get_file_path('abyssorangemix2_Hard-inpainting.safetensors'))
        hack_everything()
        self.model.eval()
        self.device = device
        self.model = self.model.to(device)

    async def _unload(self):
        """Drop references to the loaded model and tagger."""
        del self.model
        del self.tagger

    @staticmethod
    def _round_up_to_multiple(value: int, multiple: int) -> int:
        """Return the smallest multiple of ``multiple`` that is ``>= value``."""
        return -(-value // multiple) * multiple

    def _build_prompts(self, image: np.ndarray):
        """Tag ``image`` (RGB) and return ``(positive_prompt, negative_prompt)``."""
        tags = self.tagger.label_cv2_bgr(cv2.cvtColor(image, cv2.COLOR_RGB2BGR))
        self.logger.info(f'tags={list(tags.keys())}')
        blacklist = set()  # tags to exclude from the prompt; currently none
        pos_prompt = ','.join([x for x in tags.keys() if x not in blacklist]).replace('_', ' ')
        pos_prompt = 'masterpiece,best quality,' + pos_prompt
        neg_prompt = 'worst quality, low quality, normal quality,text,text,text,text'
        return pos_prompt, neg_prompt

    async def _infer(self, image: np.ndarray, mask: np.ndarray, inpainting_size: int = 1024, verbose: bool = False) -> np.ndarray:
        """Inpaint the masked region of ``image`` and composite it over the original.

        Args:
            image: RGB image; the first two axes are height and width.
            mask: Mask whose pixels ``>= 127`` are replaced. Assumed 2-D
                (height, width) — TODO confirm against callers.
            inpainting_size: Longest side the image is downscaled to (keeping the
                aspect ratio) before inference.
            verbose: Unused; kept for interface compatibility.

        Returns:
            Array with the input's height and width: model output inside the mask,
            original pixels elsewhere.
        """
        from contextlib import nullcontext

        # ``@torch.no_grad()`` on an ``async def`` only wraps creation of the
        # coroutine object, not its execution, so autograd was never disabled
        # during inference. Entering the context inside the body fixes that.
        with torch.no_grad():
            img_original = np.copy(image)
            # Boolean (H, W, 1) selector of pixels to take from the model output.
            mask_original = (mask >= 127)[:, :, None]

            height, width = image.shape[:2]
            if max(height, width) > inpainting_size:
                image = resize_keep_aspect(image, inpainting_size)
                mask = resize_keep_aspect(mask, inpainting_size)

            h, w = image.shape[:2]
            new_h = self._round_up_to_multiple(h, self._SIZE_MULTIPLE)
            new_w = self._round_up_to_multiple(w, self._SIZE_MULTIPLE)
            if new_h != h or new_w != w:
                # Stretched (not padded) to the next multiple; undone by the final resize.
                image = cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
                mask = cv2.resize(mask, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
            self.logger.info(f'Inpainting resolution: {new_w}x{new_h}')

            pos_prompt, neg_prompt = self._build_prompts(image)

            # Mixed precision only on CUDA devices; plain execution elsewhere.
            if self.device.startswith('cuda'):
                precision_ctx = torch.autocast(enabled=True, device_type='cuda')
            else:
                precision_ctx = nullcontext()
            with precision_ctx:
                img = self.model.img2img_inpaint(
                    image=Image.fromarray(image),
                    c_text=pos_prompt,
                    uc_text=neg_prompt,
                    mask=Image.fromarray(mask),
                    device=self.device,
                )

            # Maps model output from [-1, 1] (presumably its range) to [0, 255].
            # Clipping makes stray values saturate instead of wrapping on the uint8 cast.
            out = einops.rearrange(img, '1 c h w -> h w c').float().cpu().numpy()
            img_inpainted = np.clip(out * 127.5 + 127.5, 0, 255).astype(np.uint8)
            if new_h != height or new_w != width:
                img_inpainted = cv2.resize(img_inpainted, (width, height), interpolation=cv2.INTER_LINEAR)
            # Original pixels are kept outside the mask.
            return np.where(mask_original, img_inpainted, img_original)