nvn04 committed on
Commit
6219686
·
verified ·
1 Parent(s): 6f5a450

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -31
app.py CHANGED
@@ -203,40 +203,210 @@ def submit_function(
203
  return new_result_image
204
 
205
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
206
  def person_example_fn(image_path):
207
  return image_path
208
 
209
- HEADER = """
210
- <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
211
- <div style="display: flex; justify-content: center; align-items: center;">
212
- <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
213
- <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
214
- </a>
215
- <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
216
- <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
217
- </a>
218
- <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
219
- <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
220
- </a>
221
- <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
222
- <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
223
- </a>
224
- <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
225
- <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
226
- </a>
227
- <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
228
- <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
229
- </a>
230
- <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
231
- <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
232
- </a>
233
- </div>
234
- <br>
235
- · This demo and our weights are only for <span>Non-commercial Use</span>. <br>
236
- · You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
237
- · Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
238
- · SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
239
- """
240
 
241
  def app_gradio():
242
  with gr.Blocks(title="CatVTON") as demo:
 
203
  return new_result_image
204
 
205
 
206
+ @spaces.GPU(duration=120)
207
+ def submit_function(
208
+ person_image,
209
+ cloth_image,
210
+ cloth_type,
211
+ num_inference_steps,
212
+ guidance_scale,
213
+ seed,
214
+ show_type
215
+ ):
216
+ person_image, mask = person_image["background"], person_image["layers"][0]
217
+ mask = Image.open(mask).convert("L")
218
+ if len(np.unique(np.array(mask))) == 1:
219
+ mask = None
220
+ else:
221
+ mask = np.array(mask)
222
+ mask[mask > 0] = 255
223
+ mask = Image.fromarray(mask)
224
+
225
+ tmp_folder = args.output_dir
226
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
227
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
228
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
229
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
230
+
231
+ generator = None
232
+ if seed != -1:
233
+ generator = torch.Generator(device='cuda').manual_seed(seed)
234
+
235
+ person_image = Image.open(person_image).convert("RGB")
236
+ cloth_image = Image.open(cloth_image).convert("RGB")
237
+ person_image = resize_and_crop(person_image, (args.width, args.height))
238
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
239
+
240
+ # Process mask
241
+ if mask is not None:
242
+ mask = resize_and_crop(mask, (args.width, args.height))
243
+ else:
244
+ mask = automasker(
245
+ person_image,
246
+ cloth_type
247
+ )['mask']
248
+ mask = mask_processor.blur(mask, blur_factor=9)
249
+
250
+ # Inference
251
+ # try:
252
+ result_image = pipeline(
253
+ image=person_image,
254
+ condition_image=cloth_image,
255
+ mask=mask,
256
+ num_inference_steps=num_inference_steps,
257
+ guidance_scale=guidance_scale,
258
+ generator=generator
259
+ )[0]
260
+ # except Exception as e:
261
+ # raise gr.Error(
262
+ # "An error occurred. Please try again later: {}".format(e)
263
+ # )
264
+
265
+ # Post-process
266
+ masked_person = vis_mask(person_image, mask)
267
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
268
+ save_result_image.save(result_save_path)
269
+ if show_type == "result only":
270
+ return result_image
271
+ else:
272
+ width, height = person_image.size
273
+ if show_type == "input & result":
274
+ condition_width = width // 2
275
+ conditions = image_grid([person_image, cloth_image], 2, 1)
276
+ else:
277
+ condition_width = width // 3
278
+ conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
279
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
280
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
281
+ new_result_image.paste(conditions, (0, 0))
282
+ new_result_image.paste(result_image, (condition_width + 5, 0))
283
+ return new_result_image
284
+
285
+ @spaces.GPU(duration=120)
286
+ def submit_function_p2p(
287
+ person_image,
288
+ cloth_image,
289
+ num_inference_steps,
290
+ guidance_scale,
291
+ seed):
292
+ person_image= person_image["background"]
293
+
294
+ tmp_folder = args.output_dir
295
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
296
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
297
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
298
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
299
+
300
+ generator = None
301
+ if seed != -1:
302
+ generator = torch.Generator(device='cuda').manual_seed(seed)
303
+
304
+ person_image = Image.open(person_image).convert("RGB")
305
+ cloth_image = Image.open(cloth_image).convert("RGB")
306
+ person_image = resize_and_crop(person_image, (args.width, args.height))
307
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
308
+
309
+ # Inference
310
+ try:
311
+ result_image = pipeline_p2p(
312
+ image=person_image,
313
+ condition_image=cloth_image,
314
+ num_inference_steps=num_inference_steps,
315
+ guidance_scale=guidance_scale,
316
+ generator=generator
317
+ )[0]
318
+ except Exception as e:
319
+ raise gr.Error(
320
+ "An error occurred. Please try again later: {}".format(e)
321
+ )
322
+
323
+ # Post-process
324
+ save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
325
+ save_result_image.save(result_save_path)
326
+ return result_image
327
+
328
+ @spaces.GPU(duration=120)
329
+ def submit_function_flux(
330
+ person_image,
331
+ cloth_image,
332
+ cloth_type,
333
+ num_inference_steps,
334
+ guidance_scale,
335
+ seed,
336
+ show_type
337
+ ):
338
+
339
+ # Process image editor input
340
+ person_image, mask = person_image["background"], person_image["layers"][0]
341
+ mask = Image.open(mask).convert("L")
342
+ if len(np.unique(np.array(mask))) == 1:
343
+ mask = None
344
+ else:
345
+ mask = np.array(mask)
346
+ mask[mask > 0] = 255
347
+ mask = Image.fromarray(mask)
348
+
349
+ # Set random seed
350
+ generator = None
351
+ if seed != -1:
352
+ generator = torch.Generator(device='cuda').manual_seed(seed)
353
+
354
+ # Process input images
355
+ person_image = Image.open(person_image).convert("RGB")
356
+ cloth_image = Image.open(cloth_image).convert("RGB")
357
+
358
+ # Adjust image sizes
359
+ person_image = resize_and_crop(person_image, (args.width, args.height))
360
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
361
+
362
+ # Process mask
363
+ if mask is not None:
364
+ mask = resize_and_crop(mask, (args.width, args.height))
365
+ else:
366
+ mask = automasker(
367
+ person_image,
368
+ cloth_type
369
+ )['mask']
370
+ mask = mask_processor.blur(mask, blur_factor=9)
371
+
372
+ # Inference
373
+ result_image = pipeline_flux(
374
+ image=person_image,
375
+ condition_image=cloth_image,
376
+ mask_image=mask,
377
+ width=args.width,
378
+ height=args.height,
379
+ num_inference_steps=num_inference_steps,
380
+ guidance_scale=guidance_scale,
381
+ generator=generator
382
+ ).images[0]
383
+
384
+ # Post-processing
385
+ masked_person = vis_mask(person_image, mask)
386
+
387
+ # Return result based on show type
388
+ if show_type == "result only":
389
+ return result_image
390
+ else:
391
+ width, height = person_image.size
392
+ if show_type == "input & result":
393
+ condition_width = width // 2
394
+ conditions = image_grid([person_image, cloth_image], 2, 1)
395
+ else:
396
+ condition_width = width // 3
397
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
398
+
399
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
400
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
401
+ new_result_image.paste(conditions, (0, 0))
402
+ new_result_image.paste(result_image, (condition_width + 5, 0))
403
+ return new_result_image
404
+
405
+
406
  def person_example_fn(image_path):
407
  return image_path
408
 
409
+ HEADER = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
410
 
411
  def app_gradio():
412
  with gr.Blocks(title="CatVTON") as demo: