nvn04 committed on
Commit
5818481
·
verified ·
1 Parent(s): 3392168

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +201 -31
app.py CHANGED
@@ -198,40 +198,210 @@ def submit_function(
198
  return new_result_image
199
 
200
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
  def person_example_fn(image_path):
202
  return image_path
203
 
204
- HEADER = """
205
- <h1 style="text-align: center;"> 🐈 CatVTON: Concatenation Is All You Need for Virtual Try-On with Diffusion Models </h1>
206
- <div style="display: flex; justify-content: center; align-items: center;">
207
- <a href="http://arxiv.org/abs/2407.15886" style="margin: 0 2px;">
208
- <img src='https://img.shields.io/badge/arXiv-2407.15886-red?style=flat&logo=arXiv&logoColor=red' alt='arxiv'>
209
- </a>
210
- <a href='https://huggingface.co/zhengchong/CatVTON' style="margin: 0 2px;">
211
- <img src='https://img.shields.io/badge/Hugging Face-ckpts-orange?style=flat&logo=HuggingFace&logoColor=orange' alt='huggingface'>
212
- </a>
213
- <a href="https://github.com/Zheng-Chong/CatVTON" style="margin: 0 2px;">
214
- <img src='https://img.shields.io/badge/GitHub-Repo-blue?style=flat&logo=GitHub' alt='GitHub'>
215
- </a>
216
- <a href="http://120.76.142.206:8888" style="margin: 0 2px;">
217
- <img src='https://img.shields.io/badge/Demo-Gradio-gold?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
218
- </a>
219
- <a href="https://huggingface.co/spaces/zhengchong/CatVTON" style="margin: 0 2px;">
220
- <img src='https://img.shields.io/badge/Space-ZeroGPU-orange?style=flat&logo=Gradio&logoColor=red' alt='Demo'>
221
- </a>
222
- <a href='https://zheng-chong.github.io/CatVTON/' style="margin: 0 2px;">
223
- <img src='https://img.shields.io/badge/Webpage-Project-silver?style=flat&logo=&logoColor=orange' alt='webpage'>
224
- </a>
225
- <a href="https://github.com/Zheng-Chong/CatVTON/LICENCE" style="margin: 0 2px;">
226
- <img src='https://img.shields.io/badge/License-CC BY--NC--SA--4.0-lightgreen?style=flat&logo=Lisence' alt='License'>
227
- </a>
228
- </div>
229
- <br>
230
- · This demo and our weights are only for <span>Non-commercial Use</span>. <br>
231
- · You can try CatVTON in our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a> or our <a href="http://120.76.142.206:8888">online demo</a> (run on 3090). <br>
232
- · Thanks to <a href="https://huggingface.co/zero-gpu-explorers">ZeroGPU</a> for providing A100 for our <a href="https://huggingface.co/spaces/zhengchong/CatVTON">HuggingFace Space</a>. <br>
233
- · SafetyChecker is set to filter NSFW content, but it may block normal results too. Please adjust the <span>`seed`</span> for normal outcomes.<br>
234
- """
235
 
236
  def app_gradio():
237
  with gr.Blocks(title="CatVTON") as demo:
 
198
  return new_result_image
199
 
200
 
201
+ @spaces.GPU(duration=120)
202
+ def submit_function(
203
+ person_image,
204
+ cloth_image,
205
+ cloth_type,
206
+ num_inference_steps,
207
+ guidance_scale,
208
+ seed,
209
+ show_type
210
+ ):
211
+ person_image, mask = person_image["background"], person_image["layers"][0]
212
+ mask = Image.open(mask).convert("L")
213
+ if len(np.unique(np.array(mask))) == 1:
214
+ mask = None
215
+ else:
216
+ mask = np.array(mask)
217
+ mask[mask > 0] = 255
218
+ mask = Image.fromarray(mask)
219
+
220
+ tmp_folder = args.output_dir
221
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
222
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
223
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
224
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
225
+
226
+ generator = None
227
+ if seed != -1:
228
+ generator = torch.Generator(device='cuda').manual_seed(seed)
229
+
230
+ person_image = Image.open(person_image).convert("RGB")
231
+ cloth_image = Image.open(cloth_image).convert("RGB")
232
+ person_image = resize_and_crop(person_image, (args.width, args.height))
233
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
234
+
235
+ # Process mask
236
+ if mask is not None:
237
+ mask = resize_and_crop(mask, (args.width, args.height))
238
+ else:
239
+ mask = automasker(
240
+ person_image,
241
+ cloth_type
242
+ )['mask']
243
+ mask = mask_processor.blur(mask, blur_factor=9)
244
+
245
+ # Inference
246
+ # try:
247
+ result_image = pipeline(
248
+ image=person_image,
249
+ condition_image=cloth_image,
250
+ mask=mask,
251
+ num_inference_steps=num_inference_steps,
252
+ guidance_scale=guidance_scale,
253
+ generator=generator
254
+ )[0]
255
+ # except Exception as e:
256
+ # raise gr.Error(
257
+ # "An error occurred. Please try again later: {}".format(e)
258
+ # )
259
+
260
+ # Post-process
261
+ masked_person = vis_mask(person_image, mask)
262
+ save_result_image = image_grid([person_image, masked_person, cloth_image, result_image], 1, 4)
263
+ save_result_image.save(result_save_path)
264
+ if show_type == "result only":
265
+ return result_image
266
+ else:
267
+ width, height = person_image.size
268
+ if show_type == "input & result":
269
+ condition_width = width // 2
270
+ conditions = image_grid([person_image, cloth_image], 2, 1)
271
+ else:
272
+ condition_width = width // 3
273
+ conditions = image_grid([person_image, masked_person , cloth_image], 3, 1)
274
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
275
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
276
+ new_result_image.paste(conditions, (0, 0))
277
+ new_result_image.paste(result_image, (condition_width + 5, 0))
278
+ return new_result_image
279
+
280
+ @spaces.GPU(duration=120)
281
+ def submit_function_p2p(
282
+ person_image,
283
+ cloth_image,
284
+ num_inference_steps,
285
+ guidance_scale,
286
+ seed):
287
+ person_image= person_image["background"]
288
+
289
+ tmp_folder = args.output_dir
290
+ date_str = datetime.now().strftime("%Y%m%d%H%M%S")
291
+ result_save_path = os.path.join(tmp_folder, date_str[:8], date_str[8:] + ".png")
292
+ if not os.path.exists(os.path.join(tmp_folder, date_str[:8])):
293
+ os.makedirs(os.path.join(tmp_folder, date_str[:8]))
294
+
295
+ generator = None
296
+ if seed != -1:
297
+ generator = torch.Generator(device='cuda').manual_seed(seed)
298
+
299
+ person_image = Image.open(person_image).convert("RGB")
300
+ cloth_image = Image.open(cloth_image).convert("RGB")
301
+ person_image = resize_and_crop(person_image, (args.width, args.height))
302
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
303
+
304
+ # Inference
305
+ try:
306
+ result_image = pipeline_p2p(
307
+ image=person_image,
308
+ condition_image=cloth_image,
309
+ num_inference_steps=num_inference_steps,
310
+ guidance_scale=guidance_scale,
311
+ generator=generator
312
+ )[0]
313
+ except Exception as e:
314
+ raise gr.Error(
315
+ "An error occurred. Please try again later: {}".format(e)
316
+ )
317
+
318
+ # Post-process
319
+ save_result_image = image_grid([person_image, cloth_image, result_image], 1, 3)
320
+ save_result_image.save(result_save_path)
321
+ return result_image
322
+
323
+ @spaces.GPU(duration=120)
324
+ def submit_function_flux(
325
+ person_image,
326
+ cloth_image,
327
+ cloth_type,
328
+ num_inference_steps,
329
+ guidance_scale,
330
+ seed,
331
+ show_type
332
+ ):
333
+
334
+ # Process image editor input
335
+ person_image, mask = person_image["background"], person_image["layers"][0]
336
+ mask = Image.open(mask).convert("L")
337
+ if len(np.unique(np.array(mask))) == 1:
338
+ mask = None
339
+ else:
340
+ mask = np.array(mask)
341
+ mask[mask > 0] = 255
342
+ mask = Image.fromarray(mask)
343
+
344
+ # Set random seed
345
+ generator = None
346
+ if seed != -1:
347
+ generator = torch.Generator(device='cuda').manual_seed(seed)
348
+
349
+ # Process input images
350
+ person_image = Image.open(person_image).convert("RGB")
351
+ cloth_image = Image.open(cloth_image).convert("RGB")
352
+
353
+ # Adjust image sizes
354
+ person_image = resize_and_crop(person_image, (args.width, args.height))
355
+ cloth_image = resize_and_padding(cloth_image, (args.width, args.height))
356
+
357
+ # Process mask
358
+ if mask is not None:
359
+ mask = resize_and_crop(mask, (args.width, args.height))
360
+ else:
361
+ mask = automasker(
362
+ person_image,
363
+ cloth_type
364
+ )['mask']
365
+ mask = mask_processor.blur(mask, blur_factor=9)
366
+
367
+ # Inference
368
+ result_image = pipeline_flux(
369
+ image=person_image,
370
+ condition_image=cloth_image,
371
+ mask_image=mask,
372
+ width=args.width,
373
+ height=args.height,
374
+ num_inference_steps=num_inference_steps,
375
+ guidance_scale=guidance_scale,
376
+ generator=generator
377
+ ).images[0]
378
+
379
+ # Post-processing
380
+ masked_person = vis_mask(person_image, mask)
381
+
382
+ # Return result based on show type
383
+ if show_type == "result only":
384
+ return result_image
385
+ else:
386
+ width, height = person_image.size
387
+ if show_type == "input & result":
388
+ condition_width = width // 2
389
+ conditions = image_grid([person_image, cloth_image], 2, 1)
390
+ else:
391
+ condition_width = width // 3
392
+ conditions = image_grid([person_image, masked_person, cloth_image], 3, 1)
393
+
394
+ conditions = conditions.resize((condition_width, height), Image.NEAREST)
395
+ new_result_image = Image.new("RGB", (width + condition_width + 5, height))
396
+ new_result_image.paste(conditions, (0, 0))
397
+ new_result_image.paste(result_image, (condition_width + 5, 0))
398
+ return new_result_image
399
+
400
+
401
  def person_example_fn(image_path):
402
  return image_path
403
 
404
+ HEADER = ""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
405
 
406
  def app_gradio():
407
  with gr.Blocks(title="CatVTON") as demo: