Zeyu Zhao committed
Commit 4f86f8b · 1 Parent(s): 1a89cb5

Init the repo.

README.md CHANGED
@@ -1,14 +1,13 @@
 ---
 title: Sketch2Image
-emoji: 🌖
-colorFrom: indigo
-colorTo: green
+emoji: 📚
+colorFrom: green
+colorTo: blue
 sdk: gradio
-sdk_version: 5.16.0
+sdk_version: 5.15.0
 app_file: app.py
 pinned: false
-license: apache-2.0
-short_description: A Sketch2Image Demo
+license: mit
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,444 @@
+import base64
+import os
+import pdb
+import random
+import sys
+import time
+from io import BytesIO
+
+import gradio as gr
+import numpy as np
+import spaces
+import torch
+import torchvision.transforms.functional as TF
+from PIL import Image
+from torchvision import transforms
+
+from src.img2skt import image_to_sketch_gif
+from src.model import make_1step_sched
+from src.pix2pix_turbo import Pix2Pix_Turbo
+
+model = Pix2Pix_Turbo("sketch_to_image_stochastic")
+
+
+style_list = [
+    {
+        "name": "No Style",
+        "prompt": "{prompt}",
+    },
+    {
+        "name": "Cinematic",
+        "prompt": "cinematic still {prompt} . emotional, harmonious, vignette, highly detailed, high budget, bokeh, cinemascope, moody, epic, gorgeous, film grain, grainy",
+    },
+    {
+        "name": "3D Model",
+        "prompt": "professional 3d model {prompt} . octane render, highly detailed, volumetric, dramatic lighting",
+    },
+    {
+        "name": "Anime",
+        "prompt": "anime artwork {prompt} . anime style, key visual, vibrant, studio anime, highly detailed",
+    },
+    {
+        "name": "Digital Art",
+        "prompt": "concept art {prompt} . digital artwork, illustrative, painterly, matte painting, highly detailed",
+    },
+    {
+        "name": "Photographic",
+        "prompt": "cinematic photo {prompt} . 35mm photograph, film, bokeh, professional, 4k, highly detailed",
+    },
+    {
+        "name": "Pixel art",
+        "prompt": "pixel-art {prompt} . low-res, blocky, pixel art style, 8-bit graphics",
+    },
+    {
+        "name": "Fantasy art",
+        "prompt": "ethereal fantasy concept art of {prompt} . magnificent, celestial, ethereal, painterly, epic, majestic, magical, fantasy art, cover art, dreamy",
+    },
+    {
+        "name": "Neonpunk",
+        "prompt": "neonpunk style {prompt} . cyberpunk, vaporwave, neon, vibes, vibrant, stunningly beautiful, crisp, detailed, sleek, ultramodern, magenta highlights, dark purple shadows, high contrast, cinematic, ultra detailed, intricate, professional",
+    },
+    {
+        "name": "Manga",
+        "prompt": "manga style {prompt} . vibrant, high-energy, detailed, iconic, Japanese comic style",
+    },
+]
+
+styles = {k["name"]: k["prompt"] for k in style_list}
+STYLE_NAMES = list(styles.keys())
+DEFAULT_STYLE_NAME = "Manga"
+MAX_SEED = np.iinfo(np.int32).max
+
+HEIGHT = 512
+WIDTH = 512
+ITER_DELAY = 1.0
+
+
+# Create a white background image
+def create_white_background(width, height):
+    return Image.new("RGB", (width, height), color="white")
+
+
+white_background = create_white_background(WIDTH, HEIGHT)
+
+
+@spaces.GPU(duration=45)
+def run(image, prompt, prompt_template, style_name, seed, val_r):
+
+    image = image["composite"]
+    prompt = prompt_template.replace("{prompt}", prompt)
+    image = image.convert("RGB")
+    image = Image.fromarray(255 - np.array(image))
+    image_t = TF.to_tensor(image) > 0.5
+
+    with torch.no_grad():
+        c_t = image_t.unsqueeze(0).cuda().float()
+        torch.manual_seed(seed)
+        B, C, H, W = c_t.shape
+        noise = torch.randn((1, 4, H // 8, W // 8), device=c_t.device)
+        output_image = model(c_t, prompt, deterministic=False, r=val_r, noise_map=noise)
+    output_pil = TF.to_pil_image(output_image[0].cpu() * 0.5 + 0.5)
+    return output_pil
+
+
+def clear_image_editor():
+    return (
+        {"background": white_background, "layers": None, "composite": white_background},
+        gr.Image(
+            value=None,
+            height=HEIGHT,
+            width=WIDTH,
+            elem_id="output_image",
+            type="pil",
+            show_label=False,
+            show_download_button=True,
+            interactive=False,
+        ),
+        gr.Image(
+            value=None,
+            height=HEIGHT,
+            width=WIDTH,
+            show_label=False,
+            show_download_button=True,
+            type="pil",
+            interactive=False,
+        ),
+        gr.Image(
+            value=None,
+            height=HEIGHT,
+            width=WIDTH,
+            show_label=False,
+            show_download_button=True,
+            type="pil",
+            interactive=False,
+        ),
+        gr.State([]),
+        gr.Slider(
+            minimum=0,
+            maximum=1,
+            value=0,
+            step=1,
+            visible=False,
+            scale=4,
+            label="Frame Selector",
+        ),
+        gr.Button("Stop", scale=1, visible=True),
+    )
+
+
+def iter_frames(frames):
+    for frame in frames:
+        time.sleep(ITER_DELAY)
+        yield frame
+
+
+def apply_func_click():
+    return gr.Slider(
+        visible=True,
+    )
+
+
+with gr.Blocks() as demo:
+
+    with gr.Row():
+        with gr.Column():
+            # gr.Markdown("## INPUT", elem_id="input_header")
+            with gr.Row():
+                image = gr.Sketchpad(
+                    value={
+                        "background": white_background,
+                        "layers": None,
+                        "composite": white_background,
+                    },
+                    image_mode="L",
+                    type="pil",
+                    sources=None,
+                    # container=True,
+                    label="Sketch",
+                    show_label=True,
+                    show_download_button=True,
+                    # show_share_button=True,
+                    interactive=True,
+                    layers=False,
+                    canvas_size=(WIDTH, HEIGHT),
+                    show_fullscreen_button=False,
+                    brush=gr.Brush(
+                        colors=["#000000", "#FFFFFF"],
+                        color_mode="fixed",
+                        default_size=4,
+                    ),
+                )
+
+            with gr.Row():
+                prompt = gr.Textbox(label="Prompt", value="", show_label=True)
+            with gr.Row():
+                run_button = gr.Button("Run", scale=1)
+                randomize_seed = gr.Button("Random", scale=1)
+            with gr.Row():
+                apply_button = gr.Button("Stop", scale=1, visible=True)
+            with gr.Row():
+                frame_selector = gr.Slider(
+                    minimum=0,
+                    maximum=1,
+                    value=0,
+                    step=1,
+                    visible=False,
+                    scale=4,
+                    label="Frame Selector",
+                )
+
+            with gr.Row():
+                style = gr.Dropdown(
+                    label="Style",
+                    choices=STYLE_NAMES,
+                    value=DEFAULT_STYLE_NAME,
+                    scale=1,
+                    visible=False,
+                )
+                prompt_temp = gr.Textbox(
+                    label="Prompt Style Template",
+                    value=styles[DEFAULT_STYLE_NAME],
+                    max_lines=1,
+                    scale=2,
+                    visible=False,
+                )
+
+            with gr.Row():
+                val_r = gr.Slider(
+                    label="Sketch guidance: ",
+                    show_label=True,
+                    minimum=0,
+                    maximum=1,
+                    value=0.4,
+                    step=0.01,
+                    scale=4,
+                    visible=False,
+                )
+                seed = gr.Textbox(label="Seed", value=42, scale=4, visible=False)
+
+        with gr.Column():
+            # gr.Markdown("## OUTPUT", elem_id="output_header")
+            result = gr.Image(
+                height=HEIGHT,
+                width=WIDTH,
+                elem_id="output_image",
+                type="pil",
+                show_label=False,
+                show_download_button=True,
+                interactive=False,
+                visible=False,
+            )
+
+            gr.Markdown("### Instructions")
+            gr.Markdown("1. Enter a text prompt (e.g. cat)")
+            gr.Markdown("2. Draw some sketches on the Sketchpad")
+            gr.Markdown("3. Click on **Run** to generate the sketches powered by AI")
+            gr.Markdown(
+                "4. While you see the sketches coming out, click on **Stop** to stop more frames coming out"
+            )
+            gr.Markdown("5. Then you can select a frame by the Frame Selector")
+            gr.Markdown(
+                "6. You may then continue to draw more sketches or change the prompt and repeat the process"
+            )
+            gr.Markdown(
+                "7. You may try different random seeds by clicking on **Random**"
+            )
+            gr.Markdown(
+                "**Thanks to the [paper](https://arxiv.org/abs/2403.12036) and their open-sourced models!**"
+            )
+            frames = gr.State([])
+            sketches = gr.Image(
+                height=HEIGHT,
+                width=WIDTH,
+                show_label=False,
+                show_download_button=True,
+                type="pil",
+                visible=False,
+            )
+            one_frame = gr.Image(
+                height=HEIGHT,
+                width=WIDTH,
+                show_label=False,
+                show_download_button=True,
+                type="pil",
+                interactive=False,
+                visible=False,
+            )
+
+    inputs = [image, prompt, prompt_temp, style, seed, val_r]
+    outputs = [result]
+
+    randomize_seed_click = (
+        randomize_seed.click(
+            lambda: random.randint(0, MAX_SEED),
+            inputs=[],
+            outputs=seed,
+        )
+        .then(
+            fn=run,
+            inputs=inputs,
+            outputs=outputs,
+        )
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    prompt_submit = (
+        prompt.submit(fn=run, inputs=inputs, outputs=outputs)
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    style_change = (
+        style.change(lambda x: styles[x], inputs=[style], outputs=[prompt_temp])
+        .then(
+            fn=run,
+            inputs=inputs,
+            outputs=outputs,
+        )
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    val_r_change = (
+        val_r.change(run, inputs=inputs, outputs=outputs)
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    run_button_click = (
+        run_button.click(fn=run, inputs=inputs, outputs=outputs)
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    image_apply = (
+        image.apply(
+            run,
+            inputs=inputs,
+            outputs=outputs,
+        )
+        .then(
+            image_to_sketch_gif,
+            inputs=[result],
+            outputs=[sketches, frames, frame_selector, apply_button],
+        )
+        .then(
+            iter_frames,
+            inputs=[frames],
+            outputs=[image],
+        )
+    )
+
+    apply_button.click(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        cancels=[
+            run_button_click,
+            randomize_seed_click,
+            prompt_submit,
+            style_change,
+            val_r_change,
+            image_apply,
+        ],
+    )
+    apply_button.click(
+        fn=apply_func_click,
+        inputs=None,
+        outputs=[frame_selector],
+    )
+
+    frame_selector.change(
+        lambda x, y: y[x], inputs=[frame_selector, frames], outputs=[image]
+    )
+
+    image.clear(
+        fn=None,
+        inputs=None,
+        outputs=None,
+        cancels=[
+            run_button_click,
+            randomize_seed_click,
+            prompt_submit,
+            style_change,
+            val_r_change,
+            image_apply,
+        ],
+    )
+    image.clear(
+        fn=clear_image_editor,
+        inputs=None,
+        outputs=[
+            image,
+            result,
+            sketches,
+            one_frame,
+            frames,
+            frame_selector,
+            apply_button,
+        ],
+    )
+
+if __name__ == "__main__":
+    demo.queue().launch(debug=True)
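For quick local checks outside the UI, the run() wired above can be called directly with the same dict shape that gr.Sketchpad produces. A minimal sketch, assuming a CUDA device, the downloaded checkpoint, and app.py's module-level names; the blank canvas, prompt, and output path are placeholders:

from PIL import Image

canvas = Image.new("L", (WIDTH, HEIGHT), color="white")  # blank sketch, mode "L" like the Sketchpad
editor_value = {"background": canvas, "layers": None, "composite": canvas}
preview = run(editor_value, "cat", styles["Manga"], "Manga", seed=42, val_r=0.4)
preview.save("preview.png")  # placeholder output path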
requirements.txt ADDED
@@ -0,0 +1,30 @@
+clip @ git+https://github.com/openai/CLIP.git
+einops>=0.6.1
+numpy>=1.24.4
+open-clip-torch>=2.20.0
+opencv-python==4.6.0.66
+pillow>=9.5.0
+scipy==1.11.1
+timm>=0.9.2
+tokenizers
+torch>=2.1.0
+
+torchaudio>=2.0.2
+torchdata
+torchmetrics>=1.0.1
+torchvision>=0.15.2
+
+tqdm>=4.65.0
+transformers==4.43.2
+triton
+urllib3<1.27,>=1.25.4
+xformers>=0.0.20
+accelerate
+streamlit-keyup==0.2.0
+lpips
+clean-fid
+peft
+dominate
+diffusers>=0.25.1
+huggingface_hub>=0.26.0
+hf_transfer
src/__pycache__/image_prep.cpython-310.pyc ADDED
Binary file (544 Bytes).
src/__pycache__/img2skt.cpython-312.pyc ADDED
Binary file (3.13 kB).
src/__pycache__/model.cpython-310.pyc ADDED
Binary file (699 Bytes).
src/__pycache__/model.cpython-312.pyc ADDED
Binary file (3.06 kB).
src/__pycache__/pix2pix_turbo.cpython-310.pyc ADDED
Binary file (6.84 kB).
src/__pycache__/pix2pix_turbo.cpython-312.pyc ADDED
Binary file (12.1 kB).
src/image_prep.py ADDED
@@ -0,0 +1,12 @@
+import numpy as np
+from PIL import Image
+import cv2
+
+
+def canny_from_pil(image, low_threshold=100, high_threshold=200):
+    image = np.array(image)
+    image = cv2.Canny(image, low_threshold, high_threshold)
+    image = image[:, :, None]
+    image = np.concatenate([image, image, image], axis=2)
+    control_image = Image.fromarray(image)
+    return control_image
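A hypothetical standalone use of canny_from_pil (app.py itself does not import this helper; the file names are placeholders):

from PIL import Image
from src.image_prep import canny_from_pil

photo = Image.open("photo.png").convert("RGB")  # placeholder input image
edges = canny_from_pil(photo, low_threshold=100, high_threshold=200)
edges.save("edges.png")  # white Canny edges on black, returned as a 3-channel PIL image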
src/img2skt.py ADDED
@@ -0,0 +1,83 @@
+import io
+import os
+import tempfile
+
+import cv2
+import gradio as gr
+import numpy as np
+from PIL import Image, ImageSequence
+
+
+def image_to_sketch_gif(input_image: Image.Image):
+    # Convert PIL image to OpenCV format
+    open_cv_image = np.array(input_image.convert("RGB"))
+    open_cv_image = cv2.cvtColor(open_cv_image, cv2.COLOR_RGB2BGR)
+
+    # Convert to grayscale
+    grayscale_image = cv2.cvtColor(open_cv_image, cv2.COLOR_BGR2GRAY)
+
+    # Apply Gaussian blur
+    blurred_image = cv2.GaussianBlur(grayscale_image, (5, 5), 0)
+
+    # Use Canny Edge Detection
+    edges = cv2.Canny(blurred_image, threshold1=50, threshold2=150)
+
+    # Ensure binary format
+    _, binary_sketch = cv2.threshold(edges, 128, 255, cv2.THRESH_BINARY)
+
+    # Find connected components
+    num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(
+        binary_sketch, connectivity=8
+    )
+
+    # Sort components by size (excluding the background, which is label 0)
+    components = sorted(
+        [(i, stats[i, cv2.CC_STAT_AREA]) for i in range(1, num_labels)],
+        key=lambda x: x[1],
+        reverse=True,
+    )
+
+    # Initialize an empty canvas for accumulation
+    accumulated_image = np.zeros_like(binary_sketch, dtype=np.uint8)
+
+    # Store frames
+    frames = []
+
+    for label, _ in components:
+        # Add the current component to the accumulation
+        accumulated_image[labels == label] = 255
+
+        # Convert OpenCV image to PIL image and append to frames
+        pil_frame = Image.fromarray(255 - accumulated_image)
+        frames.append(pil_frame.copy())
+
+    # Add the input image as the final frame
+    frames.append(input_image.copy())
+
+    # Save GIF to a temporary file
+    tmp_dir = tempfile.gettempdir()  # Get system temp directory
+    tmp_gif_path = os.path.join(tmp_dir, "sketch_animation.gif")
+    frames[0].save(
+        tmp_gif_path,
+        format="GIF",
+        save_all=True,
+        append_images=frames[1:],
+        duration=100,
+        loop=0,
+    )
+
+    return (
+        tmp_gif_path,
+        frames,
+        gr.Slider(
+            minimum=0,
+            maximum=len(frames) - 1,
+            value=0,
+            step=1,
+            visible=False,
+            scale=4,
+            label="Frame Selector",
+            interactive=True,
+        ),
+        gr.Button("Stop", scale=1, visible=True),
+    )
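The function can also be exercised outside a Gradio Blocks context; a minimal sketch in which the input path is a placeholder and the gr.Slider/gr.Button updates in the return tuple are simply ignored:

from PIL import Image
from src.img2skt import image_to_sketch_gif

generated = Image.open("generated.png")  # placeholder: an image produced by the model
gif_path, frames, _slider_update, _button_update = image_to_sketch_gif(generated)
print(f"{len(frames)} frames, animation saved to {gif_path}")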
src/model.py ADDED
@@ -0,0 +1,71 @@
+import os, sys, pdb
+
+import diffusers
+from transformers import AutoTokenizer, PretrainedConfig
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+
+
+def make_1step_sched():
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        "stabilityai/sd-turbo", subfolder="scheduler"
+    )
+    noise_scheduler_1step = DDPMScheduler.from_pretrained(
+        "stabilityai/sd-turbo", subfolder="scheduler"
+    )
+    noise_scheduler_1step.set_timesteps(1, device="cuda")
+    noise_scheduler_1step.alphas_cumprod = noise_scheduler_1step.alphas_cumprod.cuda()
+    return noise_scheduler_1step
+
+
+"""The forward method of the `Encoder` class."""
+
+
+def my_vae_encoder_fwd(self, sample):
+    sample = self.conv_in(sample)
+    l_blocks = []
+    # down
+    for down_block in self.down_blocks:
+        l_blocks.append(sample)
+        sample = down_block(sample)
+    # middle
+    sample = self.mid_block(sample)
+    sample = self.conv_norm_out(sample)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    self.current_down_blocks = l_blocks
+    return sample
+
+
+"""The forward method of the `Decoder` class."""
+
+
+def my_vae_decoder_fwd(self, sample, latent_embeds=None):
+    sample = self.conv_in(sample)
+    upscale_dtype = next(iter(self.up_blocks.parameters())).dtype
+    # middle
+    sample = self.mid_block(sample, latent_embeds)
+    sample = sample.to(upscale_dtype)
+    if not self.ignore_skip:
+        skip_convs = [
+            self.skip_conv_1,
+            self.skip_conv_2,
+            self.skip_conv_3,
+            self.skip_conv_4,
+        ]
+        # up
+        for idx, up_block in enumerate(self.up_blocks):
+            skip_in = skip_convs[idx](self.incoming_skip_acts[::-1][idx] * self.gamma)
+            # add skip
+            sample = sample + skip_in
+            sample = up_block(sample, latent_embeds)
+    else:
+        for idx, up_block in enumerate(self.up_blocks):
+            sample = up_block(sample, latent_embeds)
+    # post-process
+    if latent_embeds is None:
+        sample = self.conv_norm_out(sample)
+    else:
+        sample = self.conv_norm_out(sample, latent_embeds)
+    sample = self.conv_act(sample)
+    sample = self.conv_out(sample)
+    return sample
@@ -0,0 +1,209 @@
 
+import spaces
+import os
+import requests
+import sys
+import pdb
+import copy
+from tqdm import tqdm
+import torch
+from transformers import AutoTokenizer, PretrainedConfig, CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler
+from diffusers.utils.peft_utils import set_weights_and_activate_adapters
+from peft import LoraConfig
+
+p = "src/"
+sys.path.append(p)
+from model import make_1step_sched, my_vae_encoder_fwd, my_vae_decoder_fwd
+
+
+class TwinConv(torch.nn.Module):
+    def __init__(self, convin_pretrained, convin_curr):
+        super(TwinConv, self).__init__()
+        self.conv_in_pretrained = copy.deepcopy(convin_pretrained)
+        self.conv_in_curr = copy.deepcopy(convin_curr)
+        self.r = None
+
+    def forward(self, x):
+        x1 = self.conv_in_pretrained(x).detach()
+        x2 = self.conv_in_curr(x)
+        return x1 * (1 - self.r) + x2 * (self.r)
+
+
+class Pix2Pix_Turbo(torch.nn.Module):
+    def __init__(self, name, ckpt_folder="checkpoints"):
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="tokenizer"
+        )
+        self.text_encoder = CLIPTextModel.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="text_encoder"
+        ).cuda()
+        self.sched = make_1step_sched()
+
+        vae = AutoencoderKL.from_pretrained("stabilityai/sd-turbo", subfolder="vae")
+        unet = UNet2DConditionModel.from_pretrained(
+            "stabilityai/sd-turbo", subfolder="unet"
+        )
+
+        if name == "edge_to_image":
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/edge_to_image_loras.pkl"
+            os.makedirs(ckpt_folder, exist_ok=True)
+            outf = os.path.join(ckpt_folder, "edge_to_image_loras.pkl")
+            if not os.path.exists(outf):
+                print(f"Downloading checkpoint to {outf}")
+                response = requests.get(url, stream=True)
+                total_size_in_bytes = int(response.headers.get("content-length", 0))
+                block_size = 1024  # 1 Kibibyte
+                progress_bar = tqdm(
+                    total=total_size_in_bytes, unit="iB", unit_scale=True
+                )
+                with open(outf, "wb") as file:
+                    for data in response.iter_content(block_size):
+                        progress_bar.update(len(data))
+                        file.write(data)
+                progress_bar.close()
+                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+                    print("ERROR, something went wrong")
+                print(f"Downloaded successfully to {outf}")
+            p_ckpt = outf
+            sd = torch.load(p_ckpt, map_location="cpu")
+            unet_lora_config = LoraConfig(
+                r=sd["rank_unet"],
+                init_lora_weights="gaussian",
+                target_modules=sd["unet_lora_target_modules"],
+            )
+
+        if name == "sketch_to_image_stochastic":
+            # download from url
+            url = "https://www.cs.cmu.edu/~img2img-turbo/models/sketch_to_image_stochastic_lora.pkl"
+            os.makedirs(ckpt_folder, exist_ok=True)
+            outf = os.path.join(ckpt_folder, "sketch_to_image_stochastic_lora.pkl")
+            if not os.path.exists(outf):
+                print(f"Downloading checkpoint to {outf}")
+                response = requests.get(url, stream=True)
+                total_size_in_bytes = int(response.headers.get("content-length", 0))
+                block_size = 1024  # 1 Kibibyte
+                progress_bar = tqdm(
+                    total=total_size_in_bytes, unit="iB", unit_scale=True
+                )
+                with open(outf, "wb") as file:
+                    for data in response.iter_content(block_size):
+                        progress_bar.update(len(data))
+                        file.write(data)
+                progress_bar.close()
+                if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
+                    print("ERROR, something went wrong")
+                print(f"Downloaded successfully to {outf}")
+            p_ckpt = outf
+            sd = torch.load(p_ckpt, map_location="cpu")
+            unet_lora_config = LoraConfig(
+                r=sd["rank_unet"],
+                init_lora_weights="gaussian",
+                target_modules=sd["unet_lora_target_modules"],
+            )
+            convin_pretrained = copy.deepcopy(unet.conv_in)
+            unet.conv_in = TwinConv(convin_pretrained, unet.conv_in)
+
+        vae.encoder.forward = my_vae_encoder_fwd.__get__(
+            vae.encoder, vae.encoder.__class__
+        )
+        vae.decoder.forward = my_vae_decoder_fwd.__get__(
+            vae.decoder, vae.decoder.__class__
+        )
+        # add the skip connection convs
+        vae.decoder.skip_conv_1 = torch.nn.Conv2d(
+            512, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_2 = torch.nn.Conv2d(
+            256, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_3 = torch.nn.Conv2d(
+            128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae.decoder.skip_conv_4 = torch.nn.Conv2d(
+            128, 256, kernel_size=(1, 1), stride=(1, 1), bias=False
+        ).cuda()
+        vae_lora_config = LoraConfig(
+            r=sd["rank_vae"],
+            init_lora_weights="gaussian",
+            target_modules=sd["vae_lora_target_modules"],
+        )
+        vae.decoder.ignore_skip = False
+        vae.add_adapter(vae_lora_config, adapter_name="vae_skip")
+        unet.add_adapter(unet_lora_config)
+        _sd_unet = unet.state_dict()
+        for k in sd["state_dict_unet"]:
+            _sd_unet[k] = sd["state_dict_unet"][k]
+        unet.load_state_dict(_sd_unet)
+
+        @spaces.GPU()
+        def wrapper(unet):
+            unet.enable_xformers_memory_efficient_attention()
+            return unet
+
+        unet = wrapper(unet)
+        _sd_vae = vae.state_dict()
+        for k in sd["state_dict_vae"]:
+            _sd_vae[k] = sd["state_dict_vae"][k]
+        vae.load_state_dict(_sd_vae)
+        unet.to("cuda")
+        vae.to("cuda")
+        unet.eval()
+        vae.eval()
+        self.unet, self.vae = unet, vae
+        self.vae.decoder.gamma = 1
+        self.timesteps = torch.tensor([999], device="cuda").long()
+
+    def forward(self, c_t, prompt, deterministic=True, r=1.0, noise_map=None):
+        # encode the text prompt
+        caption_tokens = self.tokenizer(
+            prompt,
+            max_length=self.tokenizer.model_max_length,
+            padding="max_length",
+            truncation=True,
+            return_tensors="pt",
+        ).input_ids.cuda()
+        caption_enc = self.text_encoder(caption_tokens)[0]
+        if deterministic:
+            encoded_control = (
+                self.vae.encode(c_t).latent_dist.sample()
+                * self.vae.config.scaling_factor
+            )
+            model_pred = self.unet(
+                encoded_control,
+                self.timesteps,
+                encoder_hidden_states=caption_enc,
+            ).sample
+            x_denoised = self.sched.step(
+                model_pred, self.timesteps, encoded_control, return_dict=True
+            ).prev_sample
+            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+            output_image = (
+                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+            ).clamp(-1, 1)
+        else:
+            # scale the lora weights based on the r value
+            self.unet.set_adapters(["default"], weights=[r])
+            set_weights_and_activate_adapters(self.vae, ["vae_skip"], [r])
+            encoded_control = (
+                self.vae.encode(c_t).latent_dist.sample()
+                * self.vae.config.scaling_factor
+            )
+            # combine the input and noise
+            unet_input = encoded_control * r + noise_map * (1 - r)
+            self.unet.conv_in.r = r
+            unet_output = self.unet(
+                unet_input,
+                self.timesteps,
+                encoder_hidden_states=caption_enc,
+            ).sample
+            self.unet.conv_in.r = None
+            x_denoised = self.sched.step(
+                unet_output, self.timesteps, unet_input, return_dict=True
+            ).prev_sample
+            self.vae.decoder.incoming_skip_acts = self.vae.encoder.current_down_blocks
+            self.vae.decoder.gamma = r
+            output_image = (
+                self.vae.decode(x_denoised / self.vae.config.scaling_factor).sample
+            ).clamp(-1, 1)
+        return output_image
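Putting the pieces together, the stochastic sketch-to-image path can be driven without the UI, mirroring run() in app.py. A minimal sketch, assuming a CUDA device and that the LoRA checkpoint downloads on first use; the sketch and output file names are placeholders:

import numpy as np
import torch
import torchvision.transforms.functional as TF
from PIL import Image

from src.pix2pix_turbo import Pix2Pix_Turbo

model = Pix2Pix_Turbo("sketch_to_image_stochastic")

# black strokes on a white 512x512 canvas, as drawn in the demo
sketch = Image.open("sketch.png").convert("RGB").resize((512, 512))
c_t = (TF.to_tensor(Image.fromarray(255 - np.array(sketch))) > 0.5).unsqueeze(0).cuda().float()

torch.manual_seed(42)
noise = torch.randn((1, 4, 512 // 8, 512 // 8), device=c_t.device)
with torch.no_grad():
    out = model(c_t, "manga style cat", deterministic=False, r=0.4, noise_map=noise)
TF.to_pil_image(out[0].cpu() * 0.5 + 0.5).save("output.png")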