kairusann committed
Commit
6ee3369
1 Parent(s): 7a207c9

first working version

Files changed (6)
  1. app.py +101 -0
  2. models/netG.pth +3 -0
  3. models/sk_model.pth +3 -0
  4. requirements.txt +6 -0
  5. sketch_models.py +191 -0
  6. utils.py +83 -0
app.py ADDED
@@ -0,0 +1,101 @@
+
+ import functools
+ import cv2
+ import gradio as gr
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from PIL import Image
+ from einops import rearrange
+ from sketch_models import SimpleGenerator, UnetGenerator
+ from utils import common_input_validate, resize_image_with_pad, HWC3
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+ def get_sketch(input_image, mode='anime', detect_resolution=512, output_type="pil", upscale_method="INTER_LANCZOS4", **kwargs):
+
+     input_image, output_type = common_input_validate(input_image, output_type, **kwargs)
+     detected_map, remove_pad = resize_image_with_pad(input_image, detect_resolution, upscale_method)
+
+     H, W, C = input_image.shape
+     Hn = 256 * int(np.ceil(float(H) / 256.0))
+     Wn = 256 * int(np.ceil(float(W) / 256.0))
+
+     assert detected_map.ndim == 3
+
+     if mode == 'realistic':
+         model = SimpleGenerator(3, 1, 3).to(device)
+         model.load_state_dict(torch.load("models/sk_model.pth", map_location=device))
+         model.eval()
+
+         with torch.no_grad():
+             image = torch.from_numpy(detected_map).float().to(device)
+             image = image / 255.0
+             image = rearrange(image, 'h w c -> 1 c h w')
+
+             line = model(image)[0][0]
+             line = line.cpu().numpy()
+             line = (line * 255.0).clip(0, 255).astype(np.uint8)
+
+         detected_map = HWC3(line)
+
+         detected_map = remove_pad(detected_map)
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)
+
+     elif mode == 'anime':
+         norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
+         model = UnetGenerator(3, 1, 8, 64, norm_layer=norm_layer, use_dropout=False).to(device)
+         ckpt = torch.load("models/netG.pth", map_location=device)
+         for key in list(ckpt.keys()):
+             if 'module.' in key:
+                 ckpt[key.replace('module.', '')] = ckpt[key]
+                 del ckpt[key]
+         model.load_state_dict(ckpt)
+         model.eval()
+
+         input_image = cv2.resize(input_image, (Wn, Hn), interpolation=cv2.INTER_LANCZOS4)
+
+         with torch.no_grad():
+             image_feed = torch.from_numpy(input_image).float().to(device)
+             image_feed = image_feed / 127.5 - 1.0
+             image_feed = rearrange(image_feed, 'h w c -> 1 c h w')
+
+             line = model(image_feed)[0, 0] * 127.5 + 127.5
+             line = line.cpu().numpy()
+             line = line.clip(0, 255).astype(np.uint8)
+
+         # A1111 uses INTER_AREA for downscaling, so that is likely the best choice
+         detected_map = HWC3(line)
+         detected_map = cv2.resize(detected_map, (W, H), interpolation=cv2.INTER_LANCZOS4)
+
+     else:  # standard
+         gaussian_sigma = 6.0
+         intensity_threshold = 8
+         x = detected_map.astype(np.float32)
+         g = cv2.GaussianBlur(x, (0, 0), gaussian_sigma)
+         intensity = np.min(g - x, axis=2).clip(0, 255)
+         intensity /= max(16, np.median(intensity[intensity > intensity_threshold]))
+         intensity *= 127
+         detected_map = intensity.clip(0, 255).astype(np.uint8)
+
+         detected_map = remove_pad(255 - detected_map)
+         detected_map = cv2.resize(HWC3(detected_map), (W, H), interpolation=cv2.INTER_LANCZOS4)
+
+     if output_type == "pil":
+         detected_map = Image.fromarray(detected_map)
+
+     return detected_map
+
+ iface = gr.Interface(
+     fn=get_sketch,
+     inputs=[
+         gr.Image(type="numpy", label="Upload Image"),
+         gr.Radio(["anime", "realistic", "standard"], value="anime", label="Mode", info="Process methods"),
+     ],
+     outputs=gr.Image(type="numpy", label="Sketch Output"),
+     title="Get a Sketch",
+     description="Upload an image and get a simplified sketch"
+ )
+
+ if __name__ == "__main__":
+     iface.launch()
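
For reference, a minimal sketch of calling the processor without the Gradio UI (assumes app.py and its dependencies are importable from the working directory; "input.png" and "sketch.png" are placeholder paths):

import numpy as np
from PIL import Image
from app import get_sketch  # builds the Gradio Interface on import, but does not launch it

img = np.array(Image.open("input.png").convert("RGB"), dtype=np.uint8)  # placeholder input file
sketch = get_sketch(img, mode="standard", detect_resolution=512, output_type="pil")
sketch.save("sketch.png")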
models/netG.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ccabdcc3f5cf3c07cf65d58776acb21df7dfda825cdc70c9766a93fd62bfc488
+ size 217631959
models/sk_model.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:c686ced2a666b4850b4bb6ccf0748031c3eda9f822de73a34b8979970d90f0c6
+ size 17173511
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ gradio
+ opencv-python
+ numpy
+ einops
+ pillow
sketch_models.py ADDED
@@ -0,0 +1,191 @@
+
+ import functools
+ import torch
+ import torch.nn as nn
+
+
+ norm_layer = nn.InstanceNorm2d
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, in_features):
+         super(ResidualBlock, self).__init__()
+
+         conv_block = [
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_features, in_features, 3),
+             norm_layer(in_features),
+             nn.ReLU(inplace=True),
+             nn.ReflectionPad2d(1),
+             nn.Conv2d(in_features, in_features, 3),
+             norm_layer(in_features)
+         ]
+
+         self.conv_block = nn.Sequential(*conv_block)
+
+     def forward(self, x):
+         return x + self.conv_block(x)
+
+ class UnetSkipConnectionBlock(nn.Module):
+     """Defines the Unet submodule with skip connection.
+         X -------------------identity----------------------
+         |-- downsampling -- |submodule| -- upsampling --|
+     """
+
+     def __init__(self, outer_nc, inner_nc, input_nc=None,
+                  submodule=None, outermost=False, innermost=False, norm_layer=nn.BatchNorm2d, use_dropout=False):
+         """Construct a Unet submodule with skip connections.
+         Parameters:
+             outer_nc (int) -- the number of filters in the outer conv layer
+             inner_nc (int) -- the number of filters in the inner conv layer
+             input_nc (int) -- the number of channels in input images/features
+             submodule (UnetSkipConnectionBlock) -- previously defined submodules
+             outermost (bool) -- if this module is the outermost module
+             innermost (bool) -- if this module is the innermost module
+             norm_layer -- normalization layer
+             use_dropout (bool) -- whether to use dropout layers
+         """
+         super(UnetSkipConnectionBlock, self).__init__()
+         self.outermost = outermost
+         if type(norm_layer) == functools.partial:
+             use_bias = norm_layer.func == nn.InstanceNorm2d
+         else:
+             use_bias = norm_layer == nn.InstanceNorm2d
+         if input_nc is None:
+             input_nc = outer_nc
+         downconv = nn.Conv2d(input_nc, inner_nc, kernel_size=4,
+                              stride=2, padding=1, bias=use_bias)
+         downrelu = nn.LeakyReLU(0.2, True)
+         downnorm = norm_layer(inner_nc)
+         uprelu = nn.ReLU(True)
+         upnorm = norm_layer(outer_nc)
+
+         if outermost:
+             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1)
+             down = [downconv]
+             up = [uprelu, upconv, nn.Tanh()]
+             model = down + [submodule] + up
+         elif innermost:
+             upconv = nn.ConvTranspose2d(inner_nc, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1, bias=use_bias)
+             down = [downrelu, downconv]
+             up = [uprelu, upconv, upnorm]
+             model = down + up
+         else:
+             upconv = nn.ConvTranspose2d(inner_nc * 2, outer_nc,
+                                         kernel_size=4, stride=2,
+                                         padding=1, bias=use_bias)
+             down = [downrelu, downconv, downnorm]
+             up = [uprelu, upconv, upnorm]
+
+             if use_dropout:
+                 model = down + [submodule] + up + [nn.Dropout(0.5)]
+             else:
+                 model = down + [submodule] + up
+
+         self.model = nn.Sequential(*model)
+
+     def forward(self, x):
+         if self.outermost:
+             return self.model(x)
+         else:  # add skip connections
+             return torch.cat([x, self.model(x)], 1)
+
+ class SimpleGenerator(nn.Module):
+     def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
+         super(SimpleGenerator, self).__init__()
+
+         # Initial convolution block
+         model0 = [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(input_nc, 64, 7),
+             norm_layer(64),
+             nn.ReLU(inplace=True)
+         ]
+         self.model0 = nn.Sequential(*model0)
+
+         # Downsampling
+         model1 = []
+         in_features = 64
+         out_features = in_features * 2
+         for _ in range(2):
+             model1 += [
+                 nn.Conv2d(in_features, out_features, 3, stride=2, padding=1),
+                 norm_layer(out_features),
+                 nn.ReLU(inplace=True)
+             ]
+             in_features = out_features
+             out_features = in_features * 2
+         self.model1 = nn.Sequential(*model1)
+
+         model2 = []
+         # Residual blocks
+         for _ in range(n_residual_blocks):
+             model2 += [ResidualBlock(in_features)]
+         self.model2 = nn.Sequential(*model2)
+
+         # Upsampling
+         model3 = []
+         out_features = in_features // 2
+         for _ in range(2):
+             model3 += [
+                 nn.ConvTranspose2d(in_features, out_features, 3, stride=2, padding=1, output_padding=1),
+                 norm_layer(out_features),
+                 nn.ReLU(inplace=True)
+             ]
+             in_features = out_features
+             out_features = in_features // 2
+         self.model3 = nn.Sequential(*model3)
+
+         # Output layer
+         model4 = [
+             nn.ReflectionPad2d(3),
+             nn.Conv2d(64, output_nc, 7)
+         ]
+         if sigmoid:
+             model4 += [nn.Sigmoid()]
+
+         self.model4 = nn.Sequential(*model4)
+
+     def forward(self, x, cond=None):
+         out = self.model0(x)
+         out = self.model1(out)
+         out = self.model2(out)
+         out = self.model3(out)
+         out = self.model4(out)
+
+         return out
+
+ class UnetGenerator(nn.Module):
+     """Create a Unet-based generator"""
+
+     def __init__(self, input_nc, output_nc, num_downs, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False):
+         """Construct a Unet generator.
+         Parameters:
+             input_nc (int) -- the number of channels in input images
+             output_nc (int) -- the number of channels in output images
+             num_downs (int) -- the number of downsamplings in the UNet. For example, if |num_downs| == 7,
+                                an image of size 128x128 will become of size 1x1 at the bottleneck
+             ngf (int) -- the number of filters in the last conv layer
+             norm_layer -- normalization layer
+         We construct the U-Net from the innermost layer to the outermost layer.
+         It is a recursive process.
+         """
+         super(UnetGenerator, self).__init__()
+         # construct unet structure
+         unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=None, norm_layer=norm_layer, innermost=True)  # add the innermost layer
+         for _ in range(num_downs - 5):  # add intermediate layers with ngf * 8 filters
+             unet_block = UnetSkipConnectionBlock(ngf * 8, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer, use_dropout=use_dropout)
+         # gradually reduce the number of filters from ngf * 8 to ngf
+         unet_block = UnetSkipConnectionBlock(ngf * 4, ngf * 8, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = UnetSkipConnectionBlock(ngf * 2, ngf * 4, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         unet_block = UnetSkipConnectionBlock(ngf, ngf * 2, input_nc=None, submodule=unet_block, norm_layer=norm_layer)
+         self.model = UnetSkipConnectionBlock(output_nc, ngf, input_nc=input_nc, submodule=unet_block, outermost=True, norm_layer=norm_layer)  # add the outermost layer
+
+     def forward(self, input):
+         """Standard forward"""
+         return self.model(input)
+
+
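
As a quick sanity check, the two generators can be instantiated the same way app.py does and run on a dummy tensor (a minimal sketch: random weights, CPU, and the 256x256 size is illustrative):

import functools
import torch
import torch.nn as nn
from sketch_models import SimpleGenerator, UnetGenerator

norm = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
anime = UnetGenerator(3, 1, 8, 64, norm_layer=norm, use_dropout=False)  # 8 downs: input sides must be multiples of 256
realistic = SimpleGenerator(3, 1, 3)  # 3 residual blocks, sigmoid output

x = torch.randn(1, 3, 256, 256)
print(anime(x).shape)      # torch.Size([1, 1, 256, 256])
print(realistic(x).shape)  # torch.Size([1, 1, 256, 256])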
utils.py ADDED
@@ -0,0 +1,83 @@
+ import warnings
+ import numpy as np
+ import cv2
+
+
+ UPSCALE_METHODS = ["INTER_NEAREST", "INTER_LINEAR", "INTER_AREA", "INTER_CUBIC", "INTER_LANCZOS4"]
+
+ def HWC3(x):
+     assert x.dtype == np.uint8
+     if x.ndim == 2:
+         x = x[:, :, None]
+     assert x.ndim == 3
+     H, W, C = x.shape
+     assert C == 1 or C == 3 or C == 4
+     if C == 3:
+         return x
+     if C == 1:
+         return np.concatenate([x, x, x], axis=2)
+     if C == 4:
+         color = x[:, :, 0:3].astype(np.float32)
+         alpha = x[:, :, 3:4].astype(np.float32) / 255.0
+         y = color * alpha + 255.0 * (1.0 - alpha)
+         y = y.clip(0, 255).astype(np.uint8)
+         return y
+
+ def safer_memory(x):
+     # Fix many MAC/AMD problems
+     return np.ascontiguousarray(x.copy()).copy()
+
+ def get_upscale_method(method_str):
+     assert method_str in UPSCALE_METHODS, f"Method {method_str} not found in {UPSCALE_METHODS}"
+     return getattr(cv2, method_str)
+
+ def pad64(x):
+     return int(np.ceil(float(x) / 64.0) * 64 - x)
+
+ # https://github.com/Mikubill/sd-webui-controlnet/blob/main/scripts/processor.py#L17
+ # Added upscale_method, mode params
+ def resize_image_with_pad(input_image, resolution, upscale_method="", skip_hwc3=False, mode='edge'):
+     if skip_hwc3:
+         img = input_image
+     else:
+         img = HWC3(input_image)
+     H_raw, W_raw, _ = img.shape
+     if resolution == 0:
+         return img, lambda x: x
+     k = float(resolution) / float(min(H_raw, W_raw))
+     H_target = int(np.round(float(H_raw) * k))
+     W_target = int(np.round(float(W_raw) * k))
+     img = cv2.resize(img, (W_target, H_target), interpolation=get_upscale_method(upscale_method) if k > 1 else cv2.INTER_AREA)
+     H_pad, W_pad = pad64(H_target), pad64(W_target)
+     img_padded = np.pad(img, [[0, H_pad], [0, W_pad], [0, 0]], mode=mode)
+
+     def remove_pad(x):
+         return safer_memory(x[:H_target, :W_target, ...])
+
+     return safer_memory(img_padded), remove_pad
+
+ def common_input_validate(input_image, output_type, **kwargs):
+     if "img" in kwargs:
+         warnings.warn("img is deprecated, please use `input_image=...` instead.", DeprecationWarning)
+         input_image = kwargs.pop("img")
+
+     if "return_pil" in kwargs:
+         warnings.warn("return_pil is deprecated. Use output_type instead.", DeprecationWarning)
+         output_type = "pil" if kwargs["return_pil"] else "np"
+
+     if type(output_type) is bool:
+         warnings.warn("Passing `True` or `False` to `output_type` is deprecated and will raise an error in future versions")
+         if output_type:
+             output_type = "pil"
+
+     if input_image is None:
+         raise ValueError("input_image must be defined.")
+
+     if not isinstance(input_image, np.ndarray):
+         input_image = np.array(input_image, dtype=np.uint8)
+         output_type = output_type or "pil"
+     else:
+         output_type = output_type or "np"
+
+     return (input_image, output_type)
+
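
A short round-trip sketch for resize_image_with_pad (the 480x640 input is illustrative): the short side is scaled to the requested resolution, both sides are padded up to multiples of 64, and the returned remove_pad closure strips the padding again.

import numpy as np
from utils import resize_image_with_pad

img = np.zeros((480, 640, 3), dtype=np.uint8)
padded, remove_pad = resize_image_with_pad(img, 512, "INTER_LANCZOS4")
print(padded.shape)              # (512, 704, 3): 480x640 scaled to 512x683, then padded to multiples of 64
print(remove_pad(padded).shape)  # (512, 683, 3)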