nvn04 committed
Commit 22b7932 · verified · 1 Parent(s): f717329

Update app.py

Files changed (1)
  1. app.py +484 -508
app.py CHANGED
@@ -1,508 +1,484 @@
  import argparse
  import os
  os.environ['CUDA_HOME'] = '/usr/local/cuda'
  os.environ['PATH'] = os.environ['PATH'] + ':/usr/local/cuda/bin'
  from datetime import datetime

  import gradio as gr
  import spaces
  import numpy as np
  import torch
  from diffusers.image_processor import VaeImageProcessor
  from huggingface_hub import snapshot_download
  from PIL import Image
  torch.jit.script = lambda f: f
  from model.cloth_masker import AutoMasker, vis_mask
  from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
  from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding


  def parse_args():
      parser = argparse.ArgumentParser(description="Simple example of a training script.")
      parser.add_argument(
          "--base_model_path",
          type=str,
          default="booksforcharlie/stable-diffusion-inpainting",
          help=(
              "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
          ),
      )
      parser.add_argument(
          "--p2p_base_model_path",
          type=str,
          default="timbrooks/instruct-pix2pix",
          help=(
              "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
          ),
      )
      parser.add_argument(
          "--resume_path",
          type=str,
          default="zhengchong/CatVTON",
          help=(
              "The Path to the checkpoint of trained tryon model."
          ),
      )
      parser.add_argument(
          "--output_dir",
          type=str,
          default="resource/demo/output",
          help="The output directory where the model predictions will be written.",
      )

      parser.add_argument(
          "--width",
          type=int,
          default=768,
          help=(
              "The resolution for input images, all the images in the train/validation dataset will be resized to this"
              " resolution"
          ),
      )
      parser.add_argument(
          "--height",
          type=int,
          default=1024,
          help=(
              "The resolution for input images, all the images in the train/validation dataset will be resized to this"
              " resolution"
          ),
      )
      parser.add_argument(
          "--repaint",
          action="store_true",
          help="Whether to repaint the result image with the original background."
      )
      parser.add_argument(
          "--allow_tf32",
          action="store_true",
          default=True,
          help=(
              "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
              " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
          ),
      )
      parser.add_argument(
          "--mixed_precision",
          type=str,
          default="bf16",
          choices=["no", "fp16", "bf16"],
          help=(
              "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
              " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the"
              " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
          ),
      )

      args = parser.parse_args()
      env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
      if env_local_rank != -1 and env_local_rank != args.local_rank:
          args.local_rank = env_local_rank

      return args

  def image_grid(imgs, rows, cols):
      assert len(imgs) == rows * cols

      w, h = imgs[0].size
      grid = Image.new("RGB", size=(cols * w, rows * h))

      for i, img in enumerate(imgs):
          grid.paste(img, box=(i % cols * w, i // cols * h))
      return grid


  args = parse_args()

  # Mask-based CatVTON
  catvton_repo = "zhengchong/CatVTON"
  repo_path = snapshot_download(repo_id=catvton_repo)
  # Pipeline
  pipeline = CatVTONPipeline(
      base_ckpt=args.base_model_path,
      attn_ckpt=repo_path,
      attn_ckpt_version="mix",
      weight_dtype=init_weight_dtype(args.mixed_precision),
      use_tf32=args.allow_tf32,
      device='cuda'
  )
  # AutoMasker
  mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
  automasker = AutoMasker(
      densepose_ckpt=os.path.join(repo_path, "DensePose"),
      schp_ckpt=os.path.join(repo_path, "SCHP"),
      device='cuda',
  )


- # Flux-based CatVTON
- access_token = os.getenv("HUGGING_FACE_HUB_TOKEN")
- flux_repo = "black-forest-labs/FLUX.1-Fill-dev"
- pipeline_flux = FluxTryOnPipeline.from_pretrained(flux_repo, use_auth_token=access_token)
- pipeline_flux.load_lora_weights(
-     os.path.join(repo_path, "flux-lora"),
-     weight_name='pytorch_lora_weights.safetensors'
- )
- pipeline_flux.to("cuda", init_weight_dtype(args.mixed_precision))
-
-
- # Mask-free CatVTON
- catvton_mf_repo = "zhengchong/CatVTON-MaskFree"
- repo_path_mf = snapshot_download(repo_id=catvton_mf_repo, use_auth_token=access_token)
- pipeline_p2p = CatVTONPix2PixPipeline(
-     base_ckpt=args.p2p_base_model_path,
-     attn_ckpt=repo_path_mf,
-     attn_ckpt_version="mix-48k-1024",
-     weight_dtype=init_weight_dtype(args.mixed_precision),
-     use_tf32=args.allow_tf32,
-     device='cuda'
- )
-
-
  @spaces.GPU(duration=120)
  def submit_function(
      person_image,
      cloth_image,
      cloth_type,
      num_inference_steps,
      guidance_scale,
      seed,
      show_type
  ):
      person_image, mask = person_image["background"], person_image["layers"][0]
      mask = Image.open(mask).convert("L")
      if len(np.unique(np.array(mask))) == 1:
          mask = None
      else:
          mask = np.array(mask)
          mask[mask > 0] = 255
          mask = Image.fromarray(mask)

      tmp_folder = args.output_dir
      date_str = datetime.now().strftime("%Y%m%d%H%M%S")
      result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
      if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
          os.makedirs(os.path.join(tmp_folder, date_str[:8]))

      generator = None
      if seed != -1:
          generator = torch.Generator(device='cuda').manual_seed(seed)

      person_image = Image.open(person_image).convert("RGB")
      cloth_image = Image.open(cloth_image).convert("RGB")
      person_image = resize_and_crop(person_image, (args.width, args.height))
      cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

      # Process mask
      if mask is not None:
          mask = resize_and_crop(mask, (args.width, args.height))
      else:
          mask = automasker(
              person_image,
              cloth_type
          )['mask']
      mask = mask_processor.blur(mask, blur_factor=9)

      # Inference
      # try:
      result_image = pipeline(
          image=person_image,
          condition_image=cloth_image,
          mask=mask,
          num_inference_steps=num_inference_steps,
          guidance_scale=guidance_scale,
          generator=generator
      )[0]
      # except Exception as e:
      #     raise gr.Error(
      #         "An error occurred. Please try again later: {}".format(e)
      #     )

      # Post-process
      masked_person = vis_mask(person_image, mask)
      save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
      save_result_image.save(result_save_path)
      if show_type == "result only":
          return result_image
      else:
          width, height = person_image.size
          if show_type == "input & result":
              condition_width = width // 2
              conditions = image_grid([person_image, cloth_image], 2, 1)
          else:
              condition_width = width // 3
              conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
          conditions = conditions.resize((condition_width, height), Image.NEAREST)
          new_result_image = Image.new("RGB", (width + condition_width + 5, height))
          new_result_image.paste(conditions, (0, 0))
          new_result_image.paste(result_image, (condition_width + 5, 0))
      return new_result_image

  @spaces.GPU(duration=120)
  def submit_function_p2p(
      person_image,
      cloth_image,
      num_inference_steps,
      guidance_scale,
      seed):
      person_image= person_image["background"]

      tmp_folder = args.output_dir
      date_str = datetime.now().strftime("%Y%m%d%H%M%S")
      result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
      if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
          os.makedirs(os.path.join(tmp_folder, date_str[:8]))

      generator = None
      if seed != -1:
          generator = torch.Generator(device='cuda').manual_seed(seed)

      person_image = Image.open(person_image).convert("RGB")
      cloth_image = Image.open(cloth_image).convert("RGB")
      person_image = resize_and_crop(person_image, (args.width, args.height))
      cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

      # Inference
      try:
          result_image = pipeline_p2p(
              image=person_image,
              condition_image=cloth_image,
              num_inference_steps=num_inference_steps,
              guidance_scale=guidance_scale,
              generator=generator
          )[0]
      except Exception as e:
          raise gr.Error(
              "An error occurred. Please try again later: {}".format(e)
          )

      # Post-process
      save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
      save_result_image.save(result_save_path)
      return result_image

  @spaces.GPU(duration=120)
  def submit_function_flux(
      person_image,
      cloth_image,
      cloth_type,
      num_inference_steps,
      guidance_scale,
      seed,
      show_type
  ):

      # Process image editor input
      person_image, mask = person_image["background"], person_image["layers"][0]
      mask = Image.open(mask).convert("L")
      if len(np.unique(np.array(mask))) == 1:
          mask = None
      else:
          mask = np.array(mask)
          mask[mask > 0] = 255
          mask = Image.fromarray(mask)

      # Set random seed
      generator = None
      if seed != -1:
          generator = torch.Generator(device='cuda').manual_seed(seed)

      # Process input images
      person_image = Image.open(person_image).convert("RGB")
      cloth_image = Image.open(cloth_image).convert("RGB")

      # Adjust image sizes
      person_image = resize_and_crop(person_image, (args.width, args.height))
      cloth_image = resize_and_padding(cloth_image, (args.width, args.height))

      # Process mask
      if mask is not None:
          mask = resize_and_crop(mask, (args.width, args.height))
      else:
          mask = automasker(
              person_image,
              cloth_type
          )['mask']
      mask = mask_processor.blur(mask, blur_factor=9)

      # Inference
      result_image = pipeline_flux(
          image=person_image,
          condition_image=cloth_image,
          mask_image=mask,
          width=args.width,
          height=args.height,
          num_inference_steps=num_inference_steps,
          guidance_scale=guidance_scale,
          generator=generator
      ).images[0]

      # Post-processing
      masked_person = vis_mask(person_image, mask)

      # Return result based on show type
      if show_type == "result only":
          return result_image
      else:
          width, height = person_image.size
          if show_type == "input & result":
              condition_width = width // 2
              conditions = image_grid([person_image, cloth_image], 2, 1)
          else:
              condition_width = width // 3
              conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)

          conditions = conditions.resize((condition_width, height), Image.NEAREST)
          new_result_image = Image.new("RGB", (width + condition_width + 5, height))
          new_result_image.paste(conditions, (0, 0))
          new_result_image.paste(result_image, (condition_width + 5, 0))
      return new_result_image


  def person_example_fn(image_path):
      return image_path


  HEADER = ""

  def app_gradio():
      with gr.Blocks(title="CatVTON") as demo:
          gr.Markdown(HEADER)
          with gr.Tab("Mask-based & SD1.5"):
              with gr.Row():
                  with gr.Column(scale=1, min_width=350):
                      with gr.Row():
                          image_path = gr.Image(
                              type="filepath",
                              interactive=True,
                              visible=False,
                          )
                          person_image = gr.ImageEditor(
                              interactive=True, label="Person Image", type="filepath"
                          )

                      with gr.Row():
                          with gr.Column(scale=1, min_width=230):
                              cloth_image = gr.Image(
                                  interactive=True, label="Condition Image", type="filepath"
                              )
                          with gr.Column(scale=1, min_width=120):
                              gr.Markdown(
                                  '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
                              )
                              cloth_type = gr.Radio(
                                  label="Try-On Cloth Type",
                                  choices=["upper", "lower", "overall"],
                                  value="upper",
                              )


                      submit = gr.Button("Submit")
                      gr.Markdown(
                          '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
                      )

                      gr.Markdown(
                          '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
                      )
                      with gr.Accordion("Advanced Options", open=False):
                          num_inference_steps = gr.Slider(
                              label="Inference Step", minimum=10, maximum=100, step=5, value=50
                          )
                          # Guidence Scale
                          guidance_scale = gr.Slider(
                              label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
                          )
                          # Random Seed
                          seed = gr.Slider(
                              label="Seed", minimum=-1, maximum=10000, step=1, value=42
                          )
                          show_type = gr.Radio(
                              label="Show Type",
                              choices=["result only", "input & result", "input & mask & result"],
                              value="input & mask & result",
                          )

                  with gr.Column(scale=2, min_width=500):
                      result_image = gr.Image(interactive=False, label="Result")
                      with gr.Row():
                          # Photo Examples
                          root_path = "resource/demo/example"
                          with gr.Column():
                              men_exm = gr.Examples(
                                  examples=[
                                      os.path.join(root_path, "person", "men", _)
                                      for _ in os.listdir(os.path.join(root_path, "person", "men"))
                                  ],
                                  examples_per_page=4,
                                  inputs=image_path,
                                  label="Person Examples ①",
                              )
                              women_exm = gr.Examples(
                                  examples=[
                                      os.path.join(root_path, "person", "women", _)
                                      for _ in os.listdir(os.path.join(root_path, "person", "women"))
                                  ],
                                  examples_per_page=4,
                                  inputs=image_path,
                                  label="Person Examples ②",
                              )
                              gr.Markdown(
                                  '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
                              )
                          with gr.Column():
                              condition_upper_exm = gr.Examples(
                                  examples=[
                                      os.path.join(root_path, "condition", "upper", _)
                                      for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
                                  ],
                                  examples_per_page=4,
                                  inputs=cloth_image,
                                  label="Condition Upper Examples",
                              )
                              condition_overall_exm = gr.Examples(
                                  examples=[
                                      os.path.join(root_path, "condition", "overall", _)
                                      for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
                                  ],
                                  examples_per_page=4,
                                  inputs=cloth_image,
                                  label="Condition Overall Examples",
                              )
                              condition_person_exm = gr.Examples(
                                  examples=[
                                      os.path.join(root_path, "condition", "person", _)
                                      for _ in os.listdir(os.path.join(root_path, "condition", "person"))
                                  ],
                                  examples_per_page=4,
                                  inputs=cloth_image,
                                  label="Condition Reference Person Examples",
                              )
                              gr.Markdown(
                                  '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
                              )

          image_path.change(
              person_example_fn, inputs=image_path, outputs=person_image
          )

          submit.click(
              submit_function,
              [
                  person_image,
                  cloth_image,
                  cloth_type,
                  num_inference_steps,
                  guidance_scale,
                  seed,
                  show_type,
              ],
              result_image,
          )

      demo.queue().launch(share=True, show_error=True)


  if __name__ == "__main__":
      app_gradio()