nvn04 commited on
Commit
c3e97f1
·
verified ·
1 Parent(s): 6219686

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -356
app.py CHANGED
@@ -17,17 +17,22 @@ from model.pipeline import CatVTONPipeline, CatVTONPix2PixPipeline
17
  from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
 
 
20
 
 
21
  def parse_args():
 
22
  parser = argparse.ArgumentParser(description="Simple example of a training script.")
 
23
  parser.add_argument(
24
  "--base_model_path",
25
  type=str,
26
- default="booksforcharlie/stable-diffusion-inpainting", # Change to a copy repo as runawayml delete original repo
27
  help=(
28
  "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
29
  ),
30
  )
 
31
  parser.add_argument(
32
  "--resume_path",
33
  type=str,
@@ -88,70 +93,97 @@ def parse_args():
88
  )
89
 
90
  args = parser.parse_args()
 
 
 
 
91
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
92
  if env_local_rank != -1 and env_local_rank != args.local_rank:
93
  args.local_rank = env_local_rank
94
 
95
  return args
96
 
 
97
  def image_grid(imgs, rows, cols):
98
- assert len(imgs) == rows * cols
99
 
100
  w, h = imgs[0].size
101
- grid = Image.new("RGB", size=(cols * w, rows * h))
102
 
 
103
  for i, img in enumerate(imgs):
104
  grid.paste(img, box=(i % cols * w, i // cols * h))
105
  return grid
106
 
107
 
108
  args = parse_args()
109
- repo_path = snapshot_download(repo_id=args.resume_path)
110
- # Pipeline
 
 
 
 
111
  pipeline = CatVTONPipeline(
112
- base_ckpt=args.base_model_path,
113
- attn_ckpt=repo_path,
114
  attn_ckpt_version="mix",
115
- weight_dtype=init_weight_dtype(args.mixed_precision),
116
- use_tf32=args.allow_tf32,
117
- device='cuda'
118
  )
119
- # AutoMasker
120
- mask_processor = VaeImageProcessor(vae_scale_factor=8, do_normalize=False, do_binarize=True, do_convert_grayscale=True)
 
 
 
 
 
 
 
 
121
  automasker = AutoMasker(
122
- densepose_ckpt=os.path.join(repo_path, "DensePose"),
123
- schp_ckpt=os.path.join(repo_path, "SCHP"),
124
  device='cuda',
125
  )
126
 
 
 
 
127
  def submit_function(
128
  person_image,
129
  cloth_image,
130
- cloth_type,
131
  num_inference_steps,
132
  guidance_scale,
133
  seed,
134
- show_type
135
  ):
136
- person_image, mask = person_image["background"], person_image["layers"][0]
137
- mask = Image.open(mask).convert("L")
138
- if len(np.unique(np.array(mask))) == 1:
 
 
 
139
  mask = None
140
  else:
141
- mask = np.array(mask)
142
- mask[mask > 0] = 255
143
- mask = Image.fromarray(mask)
144
-
145
- tmp_folder = args.output_dir
146
- date_str = datetime.now().strftime("%Y%m%d%H%M%S")
147
- result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
 
148
  if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
149
- os.makedirs(os.path.join(tmp_folder, date_str[:8]))
150
 
 
151
  generator = None
152
- if seed != -1:
153
  generator = torch.Generator(device='cuda').manual_seed(seed)
154
 
 
155
  person_image = Image.open(person_image).convert("RGB")
156
  cloth_image = Image.open(cloth_image).convert("RGB")
157
  person_image = resize_and_crop(person_image, (args.width, args.height))
@@ -159,14 +191,15 @@ def submit_function(
159
 
160
  # Process mask
161
  if mask is not None:
162
- mask = resize_and_crop(mask, (args.width, args.height))
163
  else:
164
  mask = automasker(
165
  person_image,
166
  cloth_type
167
- )['mask']
168
- mask = mask_processor.blur(mask, blur_factor=9)
169
 
 
170
  # Inference
171
  # try:
172
  result_image = pipeline(
@@ -182,90 +215,13 @@ def submit_function(
182
  # "An error occurred. Please try again later: {}".format(e)
183
  # )
184
 
185
- # Post-process
186
- masked_person = vis_mask(person_image, mask)
187
- save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
 
188
  save_result_image.save(result_save_path)
189
- if show_type == "result only":
190
- return result_image
191
- else:
192
- width, height = person_image.size
193
- if show_type == "input & result":
194
- condition_width = width // 2
195
- conditions = image_grid([person_image, cloth_image], 2, 1)
196
- else:
197
- condition_width = width // 3
198
- conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
199
- conditions = conditions.resize((condition_width, height), Image.NEAREST)
200
- new_result_image = Image.new("RGB", (width + condition_width + 5, height))
201
- new_result_image.paste(conditions, (0, 0))
202
- new_result_image.paste(result_image, (condition_width + 5, 0))
203
- return new_result_image
204
-
205
-
206
- @spaces.GPU(duration=120)
207
- def submit_function(
208
- person_image,
209
- cloth_image,
210
- cloth_type,
211
- num_inference_steps,
212
- guidance_scale,
213
- seed,
214
- show_type
215
- ):
216
- person_image, mask = person_image["background"], person_image["layers"][0]
217
- mask = Image.open(mask).convert("L")
218
- if len(np.unique(np.array(mask))) == 1:
219
- mask = None
220
- else:
221
- mask = np.array(mask)
222
- mask[mask > 0] = 255
223
- mask = Image.fromarray(mask)
224
-
225
- tmp_folder = args.output_dir
226
- date_str = datetime.now().strftime("%Y%m%d%H%M%S")
227
- result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
228
- if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
229
- os.makedirs(os.path.join(tmp_folder, date_str[:8]))
230
-
231
- generator = None
232
- if seed != -1:
233
- generator = torch.Generator(device='cuda').manual_seed(seed)
234
-
235
- person_image = Image.open(person_image).convert("RGB")
236
- cloth_image = Image.open(cloth_image).convert("RGB")
237
- person_image = resize_and_crop(person_image, (args.width, args.height))
238
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
239
 
240
- # Process mask
241
- if mask is not None:
242
- mask = resize_and_crop(mask, (args.width, args.height))
243
- else:
244
- mask = automasker(
245
- person_image,
246
- cloth_type
247
- )['mask']
248
- mask = mask_processor.blur(mask, blur_factor=9)
249
-
250
- # Inference
251
- # try:
252
- result_image = pipeline(
253
- image=person_image,
254
- condition_image=cloth_image,
255
- mask=mask,
256
- num_inference_steps=num_inference_steps,
257
- guidance_scale=guidance_scale,
258
- generator=generator
259
- )[0]
260
- # except Exception as e:
261
- # raise gr.Error(
262
- # "An error occurred. Please try again later: {}".format(e)
263
- # )
264
-
265
- # Post-process
266
- masked_person = vis_mask(person_image, mask)
267
- save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
268
- save_result_image.save(result_save_path)
269
  if show_type == "result only":
270
  return result_image
271
  else:
@@ -276,272 +232,165 @@ def submit_function(
276
  else:
277
  condition_width = width // 3
278
  conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
279
- conditions = conditions.resize((condition_width, height), Image.NEAREST)
280
- new_result_image = Image.new("RGB", (width + condition_width + 5, height))
281
- new_result_image.paste(conditions, (0, 0))
282
- new_result_image.paste(result_image, (condition_width + 5, 0))
283
- return new_result_image
284
-
285
- @spaces.GPU(duration=120)
286
- def submit_function_p2p(
287
- person_image,
288
- cloth_image,
289
- num_inference_steps,
290
- guidance_scale,
291
- seed):
292
- person_image= person_image["background"]
293
-
294
- tmp_folder = args.output_dir
295
- date_str = datetime.now().strftime("%Y%m%d%H%M%S")
296
- result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
297
- if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
298
- os.makedirs(os.path.join(tmp_folder, date_str[:8]))
299
-
300
- generator = None
301
- if seed != -1:
302
- generator = torch.Generator(device='cuda').manual_seed(seed)
303
-
304
- person_image = Image.open(person_image).convert("RGB")
305
- cloth_image = Image.open(cloth_image).convert("RGB")
306
- person_image = resize_and_crop(person_image, (args.width, args.height))
307
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
308
-
309
- # Inference
310
- try:
311
- result_image = pipeline_p2p(
312
- image=person_image,
313
- condition_image=cloth_image,
314
- num_inference_steps=num_inference_steps,
315
- guidance_scale=guidance_scale,
316
- generator=generator
317
- )[0]
318
- except Exception as e:
319
- raise gr.Error(
320
- "An error occurred. Please try again later: {}".format(e)
321
- )
322
-
323
- # Post-process
324
- save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
325
- save_result_image.save(result_save_path)
326
- return result_image
327
-
328
- @spaces.GPU(duration=120)
329
- def submit_function_flux(
330
- person_image,
331
- cloth_image,
332
- cloth_type,
333
- num_inference_steps,
334
- guidance_scale,
335
- seed,
336
- show_type
337
- ):
338
-
339
- # Process image editor input
340
- person_image, mask = person_image["background"], person_image["layers"][0]
341
- mask = Image.open(mask).convert("L")
342
- if len(np.unique(np.array(mask))) == 1:
343
- mask = None
344
- else:
345
- mask = np.array(mask)
346
- mask[mask > 0] = 255
347
- mask = Image.fromarray(mask)
348
-
349
- # Set random seed
350
- generator = None
351
- if seed != -1:
352
- generator = torch.Generator(device='cuda').manual_seed(seed)
353
-
354
- # Process input images
355
- person_image = Image.open(person_image).convert("RGB")
356
- cloth_image = Image.open(cloth_image).convert("RGB")
357
 
358
- # Adjust image sizes
359
- person_image = resize_and_crop(person_image, (args.width, args.height))
360
- cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
361
-
362
- # Process mask
363
- if mask is not None:
364
- mask = resize_and_crop(mask, (args.width, args.height))
365
- else:
366
- mask = automasker(
367
- person_image,
368
- cloth_type
369
- )['mask']
370
- mask = mask_processor.blur(mask, blur_factor=9)
371
-
372
- # Inference
373
- result_image = pipeline_flux(
374
- image=person_image,
375
- condition_image=cloth_image,
376
- mask_image=mask,
377
- width=args.width,
378
- height=args.height,
379
- num_inference_steps=num_inference_steps,
380
- guidance_scale=guidance_scale,
381
- generator=generator
382
- ).images[0]
383
-
384
- # Post-processing
385
- masked_person = vis_mask(person_image, mask)
386
-
387
- # Return result based on show type
388
- if show_type == "result only":
389
- return result_image
390
- else:
391
- width, height = person_image.size
392
- if show_type == "input & result":
393
- condition_width = width // 2
394
- conditions = image_grid([person_image, cloth_image], 2, 1)
395
- else:
396
- condition_width = width // 3
397
- conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
398
-
399
- conditions = conditions.resize((condition_width, height), Image.NEAREST)
400
- new_result_image = Image.new("RGB", (width + condition_width + 5, height))
401
  new_result_image.paste(conditions, (0, 0))
402
  new_result_image.paste(result_image, (condition_width + 5, 0))
403
- return new_result_image
404
 
405
 
406
  def person_example_fn(image_path):
407
  return image_path
408
 
 
409
  HEADER = ""
410
 
411
  def app_gradio():
412
  with gr.Blocks(title="CatVTON") as demo:
413
  gr.Markdown(HEADER)
414
- with gr.Row():
415
- with gr.Column(scale=1, min_width=350):
416
- with gr.Row():
417
- image_path = gr.Image(
418
- type="filepath",
419
- interactive=True,
420
- visible=False,
421
- )
422
- person_image = gr.ImageEditor(
423
- interactive=True, label="Person Image", type="filepath"
424
- )
425
-
426
- with gr.Row():
427
- with gr.Column(scale=1, min_width=230):
428
- cloth_image = gr.Image(
429
- interactive=True, label="Condition Image", type="filepath"
430
- )
431
- with gr.Column(scale=1, min_width=120):
432
- gr.Markdown(
433
- '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
434
  )
435
- cloth_type = gr.Radio(
436
- label="Try-On Cloth Type",
437
- choices=["upper", "lower", "overall"],
438
- value="upper",
439
  )
440
 
441
-
442
- submit = gr.Button("Submit")
443
- gr.Markdown(
444
- '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
445
- )
446
-
447
- gr.Markdown(
448
- '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
449
- )
450
- with gr.Accordion("Advanced Options", open=False):
451
- num_inference_steps = gr.Slider(
452
- label="Inference Step", minimum=10, maximum=100, step=5, value=50
453
- )
454
- # Guidence Scale
455
- guidance_scale = gr.Slider(
456
- label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
 
 
 
 
457
  )
458
- # Random Seed
459
- seed = gr.Slider(
460
- label="Seed", minimum=-1, maximum=10000, step=1, value=42
 
461
  )
462
- show_type = gr.Radio(
463
- label="Show Type",
464
- choices=["result only", "input & result", "input & mask & result"],
465
- value="input & mask & result",
466
- )
467
-
468
- with gr.Column(scale=2, min_width=500):
469
- result_image = gr.Image(interactive=False, label="Result")
470
- with gr.Row():
471
- # Photo Examples
472
- root_path = "resource/demo/example"
473
- with gr.Column():
474
- men_exm = gr.Examples(
475
- examples=[
476
- os.path.join(root_path, "person", "men", _)
477
- for _ in os.listdir(os.path.join(root_path, "person", "men"))
478
- ],
479
- examples_per_page=4,
480
- inputs=image_path,
481
- label="Person Examples ①",
482
- )
483
- women_exm = gr.Examples(
484
- examples=[
485
- os.path.join(root_path, "person", "women", _)
486
- for _ in os.listdir(os.path.join(root_path, "person", "women"))
487
- ],
488
- examples_per_page=4,
489
- inputs=image_path,
490
- label="Person Examples ②",
491
- )
492
- gr.Markdown(
493
- '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
494
- )
495
- with gr.Column():
496
- condition_upper_exm = gr.Examples(
497
- examples=[
498
- os.path.join(root_path, "condition", "upper", _)
499
- for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
500
- ],
501
- examples_per_page=4,
502
- inputs=cloth_image,
503
- label="Condition Upper Examples",
504
  )
505
- condition_overall_exm = gr.Examples(
506
- examples=[
507
- os.path.join(root_path, "condition", "overall", _)
508
- for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
509
- ],
510
- examples_per_page=4,
511
- inputs=cloth_image,
512
- label="Condition Overall Examples",
513
  )
514
- condition_person_exm = gr.Examples(
515
- examples=[
516
- os.path.join(root_path, "condition", "person", _)
517
- for _ in os.listdir(os.path.join(root_path, "condition", "person"))
518
- ],
519
- examples_per_page=4,
520
- inputs=cloth_image,
521
- label="Condition Reference Person Examples",
522
  )
523
- gr.Markdown(
524
- '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
 
 
525
  )
526
 
527
- image_path.change(
528
- person_example_fn, inputs=image_path, outputs=person_image
529
- )
530
-
531
- submit.click(
532
- submit_function,
533
- [
534
- person_image,
535
- cloth_image,
536
- cloth_type,
537
- num_inference_steps,
538
- guidance_scale,
539
- seed,
540
- show_type,
541
- ],
542
- result_image,
543
- )
544
- demo.queue().launch(share=True, show_error=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
545
 
546
 
547
  if __name__ == "__main__":
 
17
  from model.flux.pipeline_flux_tryon import FluxTryOnPipeline
18
  from utils import init_weight_dtype, resize_and_crop, resize_and_padding
19
 
20
+ access_token = os.getenv('HF_ACCESS_TOKEN')
21
 
22
+ # dùng để phân tích các tham số từ dòng lệnh và trả về cấu hình cài đặt cho chương trình
23
  def parse_args():
24
+ # Khởi tạo đối tượng để quản lý các tham số dòng lệnh.
25
  parser = argparse.ArgumentParser(description="Simple example of a training script.")
26
+
27
  parser.add_argument(
28
  "--base_model_path",
29
  type=str,
30
+ default="booksforcharlie/stable-diffusion-inpainting",
31
  help=(
32
  "The path to the base model to use for evaluation. This can be a local path or a model identifier from the Model Hub."
33
  ),
34
  )
35
+
36
  parser.add_argument(
37
  "--resume_path",
38
  type=str,
 
93
  )
94
 
95
  args = parser.parse_args()
96
+
97
+ # Xử lý tham số:
98
+ # Đảm bảo rằng local_rank (chỉ số GPU cục bộ khi chạy phân tán) được đồng bộ từ biến môi trường
99
+ # Khi chạy các tác vụ huấn luyện phân tán, hệ thống cần biết chỉ số GPU cục bộ để phân bổ tài nguyên.
100
  env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
101
  if env_local_rank != -1 and env_local_rank != args.local_rank:
102
  args.local_rank = env_local_rank
103
 
104
  return args
105
 
106
+ # Hàm image_grid tạo một lưới ảnh (grid) từ danh sách các ảnh đầu vào, với số hàng (rows) và số cột (cols) được chỉ định.
107
  def image_grid(imgs, rows, cols):
108
+ assert len(imgs) == rows * cols # Kiểm tra số lượng ảnh
109
 
110
  w, h = imgs[0].size
111
+ grid = Image.new("RGB", size=(cols * w, rows * h)) # Tạo ảnh trống làm lưới
112
 
113
+ #Duyệt qua các ảnh và ghép vào lưới
114
  for i, img in enumerate(imgs):
115
  grid.paste(img, box=(i % cols * w, i // cols * h))
116
  return grid
117
 
118
 
119
  args = parse_args()
120
+
121
+ # Mask-based CatVTON
122
+ catvton_repo = "zhengchong/CatVTON"
123
+ repo_path = snapshot_download(repo_id=catvton_repo) # snapshot_download: Hàm này tải toàn bộ dữ liệu mô hình từ kho lưu trữ trên Hugging Face và lưu về máy cục bộ.
124
+
125
+ # Pipeline thực hiện Virtual Try on (dùng mask)
126
  pipeline = CatVTONPipeline(
127
+ base_ckpt=args.base_model_path, # Checkpoint của mô hình cơ sở (dùng để tạo nền tảng cho pipeline).
128
+ attn_ckpt=repo_path, # Checkpoint chứa các tham số của attention module, được tải từ repo_path.
129
  attn_ckpt_version="mix",
130
+ weight_dtype=init_weight_dtype(args.mixed_precision), # Kiểu dữ liệu của trọng số mô hình. Được thiết lập bởi hàm init_weight_dtype, có thể là fp16 hoặc bf16 tùy thuộc vào GPU và cấu hình.
131
+ use_tf32=args.allow_tf32, # Cho phép sử dụng TensorFloat32 trên GPU Ampere (như A100) để tăng tốc.
132
+ device='cuda' # Thiết bị chạy mô hình (ở đây là cuda, tức GPU).
133
  )
134
+
135
+ # AutoMasker Part
136
+ # VaeImageProcessor: Bộ xử lý hình ảnh được thiết kế để làm việc với các mô hình dựa trên VAE (Variational Autoencoder).
137
+ mask_processor = VaeImageProcessor(
138
+ vae_scale_factor=8, # Tỉ lệ nén hình ảnh khi xử lý bằng VAE. Ảnh sẽ được giảm kích thước theo tỉ lệ 1/8.
139
+ do_normalize=False, # Không thực hiện chuẩn hóa giá trị pixel (ví dụ: chuyển đổi giá trị về khoảng [0, 1]).
140
+ do_binarize=True, # Chuyển đổi hình ảnh thành nhị phân (chỉ chứa 2 giá trị: 0 hoặc 255). Quan trọng để tạo mặt nạ rõ ràng.
141
+ do_convert_grayscale=True
142
+ )
143
+ # AutoMasker: Công cụ tự động tạo mặt nạ dựa trên các mô hình dự đoán hình dạng cơ thể người và phân đoạn quần áo.
144
  automasker = AutoMasker(
145
+ densepose_ckpt=os.path.join(repo_path, "DensePose"), # DensePose: Mô hình dự đoán vị trí 3D của cơ thể từ ảnh 2D.
146
+ schp_ckpt=os.path.join(repo_path, "SCHP"), # SCHP: Mô hình phân đoạn chi tiết cơ thể người (ví dụ: tách tóc, quần áo, da, v.v.).
147
  device='cuda',
148
  )
149
 
150
+ # Hàm này nhận dữ liệu đầu vào (ảnh người, ảnh quần áo, các tham số) và thực hiện các bước xử lý để trả về ảnh kết quả.
151
+ @spaces.GPU(duration=120) # Gán GPU để thực hiện hàm submit_function, với thời gian tối đa là 120 giây.
152
+ # Định nghĩa hàm nhận vào các tham số sau
153
  def submit_function(
154
  person_image,
155
  cloth_image,
156
+ cloth_type, # upper, lower, hoặc overall
157
  num_inference_steps,
158
  guidance_scale,
159
  seed,
160
+ show_type # Kiểu hiển thị kết quả (chỉ kết quả, kết hợp ảnh gốc và kết quả, hoặc hiển thị cả mặt nạ).
161
  ):
162
+ # Xử mặt nạ (mask)
163
+ person_image,
164
+ mask = person_image["background"], # Lấy ảnh người từ lớp nền.
165
+ person_image["layers"][0] # Lấy mặt nạ do người dùng vẽ (nếu có).
166
+ mask = Image.open(mask).convert("L") # Chuyển mặt nạ thành ảnh thang độ xám
167
+ if len(np.unique(np.array(mask))) == 1: # Nếu mặt nạ chỉ chứa một giá trị (ví dụ: toàn đen hoặc toàn trắng), thì không sử dụng mặt nạ (mask = None).
168
  mask = None
169
  else:
170
+ mask = np.array(mask) # Chuyển mặt nạ thành mảng numpy.
171
+ mask[mask > 0] = 255 # Các pixel có giá trị lớn hơn 0 được chuyển thành 255 (trắng).
172
+ mask = Image.fromarray(mask) # Chuyển mảng trở lại thành ảnh.
173
+
174
+ # Xử lý đường dẫn lưu trữ kết quả
175
+ tmp_folder = args.output_dir # Thư mục tạm thời lưu kết quả.
176
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S") # Chuỗi ngày giờ hiện tại (ví dụ: 20250108).
177
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png") # Đường dẫn đầy đủ để lưu ảnh kết quả.
178
  if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
179
+ os.makedirs(os.path.join(tmp_folder, date_str[:8])) # Tạo thư mục lưu trữ nếu chưa tồn tại.
180
 
181
+ # Xử lý seed ngẫu nhiên
182
  generator = None
183
+ if seed != -1: # Nếu seed được cung cấp, mô hình sẽ sử dụng giá trị này để sinh dữ liệu (giữ tính ngẫu nhiên nhưng tái tạo được).
184
  generator = torch.Generator(device='cuda').manual_seed(seed)
185
 
186
+ # Chuẩn hóa ảnh đầu vào
187
  person_image = Image.open(person_image).convert("RGB")
188
  cloth_image = Image.open(cloth_image).convert("RGB")
189
  person_image = resize_and_crop(person_image, (args.width, args.height))
 
191
 
192
  # Process mask
193
  if mask is not None:
194
+ mask = resize_and_crop(mask, (args.width, args.height)) # Nếu mặt nạ được cung cấp, thay đổi kích thước cho phù hợp.
195
  else:
196
  mask = automasker(
197
  person_image,
198
  cloth_type
199
+ )['mask'] # Nếu không, tạo mặt nạ tự động bằng automasker, dựa trên loại quần áo (cloth_type).
200
+ mask = mask_processor.blur(mask, blur_factor=9) # Làm mờ mặt nạ (blur) để giảm bớt các cạnh sắc
201
 
202
+ # Suy luận mô hình: gán các tham số vô hàm tính toán, trả lại result là hình ảnh
203
  # Inference
204
  # try:
205
  result_image = pipeline(
 
215
  # "An error occurred. Please try again later: {}".format(e)
216
  # )
217
 
218
+ # Post-process - Xử lý hậu kỳ
219
+ # Tạo ảnh kết quả lưới
220
+ masked_person = vis_mask(person_image, mask) # Hiển thị ảnh người với mặt nạ được áp dụng.
221
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4) # Tạo một ảnh lưới chứa
222
  save_result_image.save(result_save_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
223
 
224
+ # Điều chỉnh hiển thị kết quả
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  if show_type == "result only":
226
  return result_image
227
  else:
 
232
  else:
233
  condition_width = width // 3
234
  conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
235
 
236
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
237
+ # conditions: Ảnh ghép ban đầu, được tạo từ các ảnh như ảnh người gốc, ảnh quần áo, và ảnh mặt nạ (tùy chọn).
238
+ # Tham số Image.NEAREST: Đây là phương pháp nội suy (interpolation) gần nhất, dùng để thay đổi kích thước ảnh mà không làm mờ hay mất chi tiết.
239
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height)) # Image.new: Tạo một ảnh trống mới
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
240
  new_result_image.paste(conditions, (0, 0))
241
  new_result_image.paste(result_image, (condition_width + 5, 0))
242
+ return new_result_image
243
 
244
 
245
  def person_example_fn(image_path):
246
  return image_path
247
 
248
+
249
  HEADER = ""
250
 
251
  def app_gradio():
252
  with gr.Blocks(title="CatVTON") as demo:
253
  gr.Markdown(HEADER)
254
+ with gr.Tab("Mask-based"):
255
+ with gr.Row():
256
+ with gr.Column(scale=1, min_width=350):
257
+ # Ảnh model (người)
258
+ with gr.Row():
259
+ image_path = gr.Image(
260
+ type="filepath",
261
+ interactive=True,
262
+ visible=False,
 
 
 
 
 
 
 
 
 
 
 
263
  )
264
+ person_image = gr.ImageEditor(
265
+ interactive=True, label="Person Image", type="filepath"
 
 
266
  )
267
 
268
+ # Ảnh quần áo
269
+ with gr.Row():
270
+ with gr.Column(scale=1, min_width=230):
271
+ cloth_image = gr.Image(
272
+ interactive=True, label="Condition Image", type="filepath"
273
+ )
274
+ with gr.Column(scale=1, min_width=120):
275
+ gr.Markdown(
276
+ '<span style="color: #808080; font-size: small;">Two ways to provide Mask:<br>1. Upload the person image and use the `🖌️` above to draw the Mask (higher priority)<br>2. Select the `Try-On Cloth Type` to generate automatically </span>'
277
+ )
278
+ cloth_type = gr.Radio(
279
+ label="Try-On Cloth Type",
280
+ choices=["upper", "lower", "overall"],
281
+ value="upper",
282
+ )
283
+
284
+ # Submit button - Run
285
+ submit = gr.Button("Submit")
286
+ gr.Markdown(
287
+ '<center><span style="color: #FF0000">!!! Click only Once, Wait for Delay !!!</span></center>'
288
  )
289
+
290
+ # Advance setting
291
+ gr.Markdown(
292
+ '<span style="color: #808080; font-size: small;">Advanced options can adjust details:<br>1. `Inference Step` may enhance details;<br>2. `CFG` is highly correlated with saturation;<br>3. `Random seed` may improve pseudo-shadow.</span>'
293
  )
294
+ with gr.Accordion("Advanced Options", open=False):
295
+ num_inference_steps = gr.Slider(
296
+ label="Inference Step", minimum=10, maximum=100, step=5, value=50
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
297
  )
298
+ # Guidence Scale
299
+ guidance_scale = gr.Slider(
300
+ label="CFG Strenth", minimum=0.0, maximum=7.5, step=0.5, value=2.5
 
 
 
 
 
301
  )
302
+ # Random Seed
303
+ seed = gr.Slider(
304
+ label="Seed", minimum=-1, maximum=10000, step=1, value=42
 
 
 
 
 
305
  )
306
+ show_type = gr.Radio(
307
+ label="Show Type",
308
+ choices=["result only", "input & result", "input & mask & result"],
309
+ value="input & mask & result",
310
  )
311
 
312
+
313
+ with gr.Column(scale=2, min_width=500):
314
+ # Result image
315
+ result_image = gr.Image(interactive=False, label="Result")
316
+ with gr.Row():
317
+ # Photo Examples
318
+ root_path = "resource/demo/example"
319
+ with gr.Column():
320
+ men_exm = gr.Examples(
321
+ examples=[
322
+ os.path.join(root_path, "person", "men", _)
323
+ for _ in os.listdir(os.path.join(root_path, "person", "men"))
324
+ ],
325
+ examples_per_page=4,
326
+ inputs=image_path,
327
+ label="Person Examples ①",
328
+ )
329
+ women_exm = gr.Examples(
330
+ examples=[
331
+ os.path.join(root_path, "person", "women", _)
332
+ for _ in os.listdir(os.path.join(root_path, "person", "women"))
333
+ ],
334
+ examples_per_page=4,
335
+ inputs=image_path,
336
+ label="Person Examples ②",
337
+ )
338
+ gr.Markdown(
339
+ '<span style="color: #808080; font-size: small;">*Person examples come from the demos of <a href="https://huggingface.co/spaces/levihsu/OOTDiffusion">OOTDiffusion</a> and <a href="https://www.outfitanyone.org">OutfitAnyone</a>. </span>'
340
+ )
341
+ with gr.Column():
342
+ condition_upper_exm = gr.Examples(
343
+ examples=[
344
+ os.path.join(root_path, "condition", "upper", _)
345
+ for _ in os.listdir(os.path.join(root_path, "condition", "upper"))
346
+ ],
347
+ examples_per_page=4,
348
+ inputs=cloth_image,
349
+ label="Condition Upper Examples",
350
+ )
351
+ condition_overall_exm = gr.Examples(
352
+ examples=[
353
+ os.path.join(root_path, "condition", "overall", _)
354
+ for _ in os.listdir(os.path.join(root_path, "condition", "overall"))
355
+ ],
356
+ examples_per_page=4,
357
+ inputs=cloth_image,
358
+ label="Condition Overall Examples",
359
+ )
360
+ condition_person_exm = gr.Examples(
361
+ examples=[
362
+ os.path.join(root_path, "condition", "person", _)
363
+ for _ in os.listdir(os.path.join(root_path, "condition", "person"))
364
+ ],
365
+ examples_per_page=4,
366
+ inputs=cloth_image,
367
+ label="Condition Reference Person Examples",
368
+ )
369
+ gr.Markdown(
370
+ '<span style="color: #808080; font-size: small;">*Condition examples come from the Internet. </span>'
371
+ )
372
+
373
+ image_path.change(
374
+ person_example_fn, inputs=image_path, outputs=person_image
375
+ )
376
+
377
+ # Function khi ấn nút submit
378
+ submit.click(
379
+ submit_function,
380
+ [
381
+ person_image,
382
+ cloth_image,
383
+ cloth_type,
384
+ num_inference_steps,
385
+ guidance_scale,
386
+ seed,
387
+ show_type,
388
+ ],
389
+ result_image,
390
+ )
391
+
392
+ # demo.queue().launch(share=True, show_error=True)
393
+ demo.queue().launch()
394
 
395
 
396
  if __name__ == "__main__":