mr2along committed
Commit 8e3bca8 · verified · 1 Parent(s): 771989b

Update app.py

Files changed (1)
  1. app.py +66 -840
app.py CHANGED
@@ -1,894 +1,120 @@
  import os
  import cv2
- import glob
  import time
  import torch
- import shutil
  import argparse
- import platform
- import datetime
- import subprocess
  import insightface
  import onnxruntime
  import numpy as np
  import gradio as gr
- import threading
- import queue
  from tqdm import tqdm
- import concurrent.futures
- from moviepy.editor import VideoFileClip

  from face_swapper import Inswapper, paste_to_whole
- from face_analyser import detect_conditions, get_analysed_data, swap_options_list
- from face_parsing import init_parsing_model, get_parsed_mask, mask_regions, mask_regions_to_list
- from face_enhancer import get_available_enhancer_names, load_face_enhancer_model, cv2_interpolations
- from utils import trim_video, StreamerThread, ProcessBar, open_directory, split_list_by_lengths, merge_img_sequence_from_ref, create_image_grid

  ## ------------------------------ USER ARGS ------------------------------

- parser = argparse.ArgumentParser(description="Free Face Swapper")
- parser.add_argument("--out_dir", help="Default Output directory", default=os.getcwd())
- parser.add_argument("--batch_size", help="Gpu batch size", default=32)
- parser.add_argument("--cuda", action="store_true", help="Enable cuda", default=False)
- parser.add_argument(
-     "--colab", action="store_true", help="Enable colab mode", default=False
- )
  user_args = parser.parse_args()

- ## ------------------------------ DEFAULTS ------------------------------
- 
- USE_COLAB = user_args.colab
  USE_CUDA = user_args.cuda
  DEF_OUTPUT_PATH = user_args.out_dir
  BATCH_SIZE = int(user_args.batch_size)
- WORKSPACE = None
- OUTPUT_FILE = None
- CURRENT_FRAME = None
- STREAMER = None
- DETECT_CONDITION = "best detection"
- DETECT_SIZE = 640
- DETECT_THRESH = 0.6
- NUM_OF_SRC_SPECIFIC = 10
- MASK_INCLUDE = [
-     "Skin",
-     "R-Eyebrow",
-     "L-Eyebrow",
-     "L-Eye",
-     "R-Eye",
-     "Nose",
-     "Mouth",
-     "L-Lip",
-     "U-Lip"
- ]
- MASK_SOFT_KERNEL = 17
- MASK_SOFT_ITERATIONS = 10
- MASK_BLUR_AMOUNT = 0.1
- MASK_ERODE_AMOUNT = 0.15

- FACE_SWAPPER = None
- FACE_ANALYSER = None
- FACE_ENHANCER = None
- FACE_PARSER = None
- FACE_ENHANCER_LIST = ["NONE"]
- FACE_ENHANCER_LIST.extend(get_available_enhancer_names())
- FACE_ENHANCER_LIST.extend(cv2_interpolations)
- 
- ## ------------------------------ SET EXECUTION PROVIDER ------------------------------
- # Note: Non CUDA users may change settings here

  PROVIDER = ["CPUExecutionProvider"]
- 
  if USE_CUDA:
-     available_providers = onnxruntime.get_available_providers()
-     if "CUDAExecutionProvider" in available_providers:
-         print("\n********** Running on CUDA **********\n")
          PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
      else:
          USE_CUDA = False
-         print("\n********** CUDA unavailable running on CPU **********\n")
- else:
-     USE_CUDA = False
-     print("\n********** Running on CPU **********\n")

  device = "cuda" if USE_CUDA else "cpu"
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None

  ## ------------------------------ LOAD MODELS ------------------------------

- def load_face_analyser_model(name="buffalo_l"):
-     global FACE_ANALYSER
-     if FACE_ANALYSER is None:
-         FACE_ANALYSER = insightface.app.FaceAnalysis(name=name, providers=PROVIDER)
-         FACE_ANALYSER.prepare(
-             ctx_id=0, det_size=(DETECT_SIZE, DETECT_SIZE), det_thresh=DETECT_THRESH
-         )
- 
- 
- def load_face_swapper_model(path="./assets/pretrained_models/inswapper_128.onnx"):
-     global FACE_SWAPPER
-     if FACE_SWAPPER is None:
-         batch = int(BATCH_SIZE) if device == "cuda" else 1
-         FACE_SWAPPER = Inswapper(model_file=path, batch_size=batch, providers=PROVIDER)
- 
- 
- def load_face_parser_model(path="./assets/pretrained_models/79999_iter.pth"):
-     global FACE_PARSER
-     if FACE_PARSER is None:
-         FACE_PARSER = init_parsing_model(path, device=device)


- load_face_analyser_model()
- load_face_swapper_model()
- 
- ## ------------------------------ MAIN PROCESS ------------------------------
- 
- 
- def process(
-     input_type,
-     image_path,
-     video_path,
-     directory_path,
-     source_path,
-     output_path,
-     output_name,
-     keep_output_sequence,
-     condition,
-     age,
-     distance,
-     face_enhancer_name,
-     enable_face_parser,
-     mask_includes,
-     mask_soft_kernel,
-     mask_soft_iterations,
-     blur_amount,
-     erode_amount,
-     face_scale,
-     enable_laplacian_blend,
-     crop_top,
-     crop_bott,
-     crop_left,
-     crop_right,
-     *specifics,
- ):
-     global WORKSPACE
-     global OUTPUT_FILE
-     global PREVIEW
-     WORKSPACE, OUTPUT_FILE, PREVIEW = None, None, None
- 
-     ## ------------------------------ GUI UPDATE FUNC ------------------------------
- 
-     def ui_before():
-         return (
-             gr.update(visible=True, value=PREVIEW),
-             gr.update(interactive=False),
-             gr.update(interactive=False),
-             gr.update(visible=False),
-         )
- 
-     def ui_after():
-         return (
-             gr.update(visible=True, value=PREVIEW),
-             gr.update(interactive=True),
-             gr.update(interactive=True),
-             gr.update(visible=False),
-         )
- 
-     def ui_after_vid():
-         return (
-             gr.update(visible=False),
-             gr.update(interactive=True),
-             gr.update(interactive=True),
-             gr.update(value=OUTPUT_FILE, visible=True),
-         )

      start_time = time.time()
-     total_exec_time = lambda start_time: divmod(time.time() - start_time, 60)
-     get_finsh_text = lambda start_time: f"✔️ Completed in {int(total_exec_time(start_time)[0])} min {int(total_exec_time(start_time)[1])} sec."

-     ## ------------------------------ PREPARE INPUTS & LOAD MODELS ------------------------------
- 
-     yield "### \n ⌛ Loading face analyser model...", *ui_before()
-     load_face_analyser_model()

-     yield "### \n ⌛ Loading face swapper model...", *ui_before()
-     load_face_swapper_model()

-     if face_enhancer_name != "NONE":
-         if face_enhancer_name not in cv2_interpolations:
-             yield f"### \n ⌛ Loading {face_enhancer_name} model...", *ui_before()
-         FACE_ENHANCER = load_face_enhancer_model(name=face_enhancer_name, device=device)
-     else:
-         FACE_ENHANCER = None
- 
-     if enable_face_parser:
-         yield "### \n ⌛ Loading face parsing model...", *ui_before()
-         load_face_parser_model()
- 
-     includes = mask_regions_to_list(mask_includes)
-     specifics = list(specifics)
-     half = len(specifics) // 2
-     sources = specifics[:half]
-     specifics = specifics[half:]
-     if crop_top > crop_bott:
-         crop_top, crop_bott = crop_bott, crop_top
-     if crop_left > crop_right:
-         crop_left, crop_right = crop_right, crop_left
-     crop_mask = (crop_top, 511-crop_bott, crop_left, 511-crop_right)
- 
-     def swap_process(image_sequence):
-         ## ------------------------------ CONTENT CHECK ------------------------------
- 
-         yield "### \n ⌛ Analysing face data...", *ui_before()
-         if condition != "Specific Face":
-             source_data = source_path, age
-         else:
-             source_data = ((sources, specifics), distance)
-         analysed_targets, analysed_sources, whole_frame_list, num_faces_per_frame = get_analysed_data(
-             FACE_ANALYSER,
-             image_sequence,
-             source_data,
-             swap_condition=condition,
-             detect_condition=DETECT_CONDITION,
-             scale=face_scale
-         )
- 
-         ## ------------------------------ SWAP FUNC ------------------------------
- 
-         yield "### \n ⌛ Generating faces...", *ui_before()
-         preds = []
-         matrs = []
-         count = 0
-         global PREVIEW
-         for batch_pred, batch_matr in FACE_SWAPPER.batch_forward(whole_frame_list, analysed_targets, analysed_sources):
-             preds.extend(batch_pred)
-             matrs.extend(batch_matr)
-             EMPTY_CACHE()
-             count += 1
- 
-             if USE_CUDA:
-                 image_grid = create_image_grid(batch_pred, size=128)
-                 PREVIEW = image_grid[:, :, ::-1]
-                 yield f"### \n ⌛ Generating face Batch {count}", *ui_before()
- 
-         ## ------------------------------ FACE ENHANCEMENT ------------------------------
- 
-         generated_len = len(preds)
-         if face_enhancer_name != "NONE":
-             yield f"### \n ⌛ Upscaling faces with {face_enhancer_name}...", *ui_before()
-             for idx, pred in tqdm(enumerate(preds), total=generated_len, desc=f"Upscaling with {face_enhancer_name}"):
-                 enhancer_model, enhancer_model_runner = FACE_ENHANCER
-                 pred = enhancer_model_runner(pred, enhancer_model)
-                 preds[idx] = cv2.resize(pred, (512,512))
          EMPTY_CACHE()

-         ## ------------------------------ FACE PARSING ------------------------------
- 
-         if enable_face_parser:
-             yield "### \n ⌛ Face-parsing mask...", *ui_before()
-             masks = []
-             count = 0
-             for batch_mask in get_parsed_mask(FACE_PARSER, preds, classes=includes, device=device, batch_size=BATCH_SIZE, softness=int(mask_soft_iterations)):
-                 masks.append(batch_mask)
-                 EMPTY_CACHE()
-                 count += 1
- 
-                 if len(batch_mask) > 1:
-                     image_grid = create_image_grid(batch_mask, size=128)
-                     PREVIEW = image_grid[:, :, ::-1]
-                     yield f"### \n ⌛ Face parsing Batch {count}", *ui_before()
-             masks = np.concatenate(masks, axis=0) if len(masks) >= 1 else masks
-         else:
-             masks = [None] * generated_len
- 
-         ## ------------------------------ SPLIT LIST ------------------------------
- 
-         split_preds = split_list_by_lengths(preds, num_faces_per_frame)
-         del preds
-         split_matrs = split_list_by_lengths(matrs, num_faces_per_frame)
-         del matrs
-         split_masks = split_list_by_lengths(masks, num_faces_per_frame)
-         del masks
- 
-         ## ------------------------------ PASTE-BACK ------------------------------
- 
-         yield "### \n ⌛ Pasting back...", *ui_before()
-         def post_process(frame_idx, frame_img, split_preds, split_matrs, split_masks, enable_laplacian_blend, crop_mask, blur_amount, erode_amount):
-             whole_img_path = frame_img
-             whole_img = cv2.imread(whole_img_path)
-             blend_method = 'laplacian' if enable_laplacian_blend else 'linear'
-             for p, m, mask in zip(split_preds[frame_idx], split_matrs[frame_idx], split_masks[frame_idx]):
-                 p = cv2.resize(p, (512,512))
-                 mask = cv2.resize(mask, (512,512)) if mask is not None else None
-                 m /= 0.25
-                 whole_img = paste_to_whole(p, whole_img, m, mask=mask, crop_mask=crop_mask, blend_method=blend_method, blur_amount=blur_amount, erode_amount=erode_amount)
-             cv2.imwrite(whole_img_path, whole_img)
- 
-         def concurrent_post_process(image_sequence, *args):
-             with concurrent.futures.ThreadPoolExecutor() as executor:
-                 futures = []
-                 for idx, frame_img in enumerate(image_sequence):
-                     future = executor.submit(post_process, idx, frame_img, *args)
-                     futures.append(future)
- 
-                 for future in tqdm(concurrent.futures.as_completed(futures), total=len(futures), desc="Pasting back"):
-                     result = future.result()
- 
-         concurrent_post_process(
-             image_sequence,
-             split_preds,
-             split_matrs,
-             split_masks,
-             enable_laplacian_blend,
-             crop_mask,
-             blur_amount,
-             erode_amount
-         )
- 
- 
-     ## ------------------------------ IMAGE ------------------------------
- 
-     if input_type == "Image":
-         target = cv2.imread(image_path)
-         output_file = os.path.join(output_path, output_name + ".png")
-         cv2.imwrite(output_file, target)
- 
-         for info_update in swap_process([output_file]):
-             yield info_update
- 
-         OUTPUT_FILE = output_file
-         WORKSPACE = output_path
-         PREVIEW = cv2.imread(output_file)[:, :, ::-1]
- 
-         yield get_finsh_text(start_time), *ui_after()
- 
-     ## ------------------------------ VIDEO ------------------------------
- 
-     elif input_type == "Video":
-         temp_path = os.path.join(output_path, output_name, "sequence")
-         os.makedirs(temp_path, exist_ok=True)
- 
-         yield "### \n ⌛ Extracting video frames...", *ui_before()
-         image_sequence = []
-         cap = cv2.VideoCapture(video_path)
-         curr_idx = 0
-         while True:
-             ret, frame = cap.read()
-             if not ret:break
-             frame_path = os.path.join(temp_path, f"frame_{curr_idx}.jpg")
-             cv2.imwrite(frame_path, frame)
-             image_sequence.append(frame_path)
-             curr_idx += 1
-         cap.release()
-         cv2.destroyAllWindows()
- 
-         for info_update in swap_process(image_sequence):
-             yield info_update
- 
-         yield "### \n ⌛ Merging sequence...", *ui_before()
-         output_video_path = os.path.join(output_path, output_name + ".mp4")
-         merge_img_sequence_from_ref(video_path, image_sequence, output_video_path)
- 
-         if os.path.exists(temp_path) and not keep_output_sequence:
-             yield "### \n ⌛ Removing temporary files...", *ui_before()
-             shutil.rmtree(temp_path)
- 
-         WORKSPACE = output_path
-         OUTPUT_FILE = output_video_path
- 
-         yield get_finsh_text(start_time), *ui_after_vid()
- 
-     ## ------------------------------ DIRECTORY ------------------------------
- 
-     elif input_type == "Directory":
-         extensions = ["jpg", "jpeg", "png", "bmp", "tiff", "ico", "webp"]
-         temp_path = os.path.join(output_path, output_name)
-         if os.path.exists(temp_path):
-             shutil.rmtree(temp_path)
-         os.mkdir(temp_path)
- 
-         file_paths = []
-         for file_path in glob.glob(os.path.join(directory_path, "*")):
-             if any(file_path.lower().endswith(ext) for ext in extensions):
-                 img = cv2.imread(file_path)
-                 new_file_path = os.path.join(temp_path, os.path.basename(file_path))
-                 cv2.imwrite(new_file_path, img)
-                 file_paths.append(new_file_path)
- 
-         for info_update in swap_process(file_paths):
-             yield info_update
- 
-         PREVIEW = cv2.imread(file_paths[-1])[:, :, ::-1]
-         WORKSPACE = temp_path
-         OUTPUT_FILE = file_paths[-1]
- 
-         yield get_finsh_text(start_time), *ui_after()
- 
-     ## ------------------------------ STREAM ------------------------------
- 
-     elif input_type == "Stream":
-         pass
- 
- 
- ## ------------------------------ GRADIO FUNC ------------------------------
- 
- 
- def update_radio(value):
-     if value == "Image":
-         return (
-             gr.update(visible=True),
-             gr.update(visible=False),
-             gr.update(visible=False),
-         )
-     elif value == "Video":
-         return (
-             gr.update(visible=False),
-             gr.update(visible=True),
-             gr.update(visible=False),
-         )
-     elif value == "Directory":
-         return (
-             gr.update(visible=False),
-             gr.update(visible=False),
-             gr.update(visible=True),
-         )
-     elif value == "Stream":
-         return (
-             gr.update(visible=False),
-             gr.update(visible=False),
-             gr.update(visible=True),
-         )
- 
- 
- def swap_option_changed(value):
-     if value.startswith("Age"):
-         return (
-             gr.update(visible=True),
-             gr.update(visible=False),
-             gr.update(visible=True),
-         )
-     elif value == "Specific Face":
-         return (
-             gr.update(visible=False),
-             gr.update(visible=True),
-             gr.update(visible=False),
-         )
-     return gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)
- 
- 
- def video_changed(video_path):
-     sliders_update = gr.Slider.update
-     button_update = gr.Button.update
-     number_update = gr.Number.update
- 
-     if video_path is None:
-         return (
-             sliders_update(minimum=0, maximum=0, value=0),
-             sliders_update(minimum=1, maximum=1, value=1),
-             number_update(value=1),
-         )
-     try:
-         clip = VideoFileClip(video_path)
-         fps = clip.fps
-         total_frames = clip.reader.nframes
-         clip.close()
-         return (
-             sliders_update(minimum=0, maximum=total_frames, value=0, interactive=True),
-             sliders_update(
-                 minimum=0, maximum=total_frames, value=total_frames, interactive=True
-             ),
-             number_update(value=fps),
-         )
-     except:
-         return (
-             sliders_update(value=0),
-             sliders_update(value=0),
-             number_update(value=1),
-         )
- 
- 
- def analyse_settings_changed(detect_condition, detection_size, detection_threshold):
-     yield "### \n ⌛ Applying new values..."
-     global FACE_ANALYSER
-     global DETECT_CONDITION
-     DETECT_CONDITION = detect_condition
-     FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
-     FACE_ANALYSER.prepare(
-         ctx_id=0,
-         det_size=(int(detection_size), int(detection_size)),
-         det_thresh=float(detection_threshold),
-     )
-     yield f"### \n ✔️ Applied detect condition:{detect_condition}, detection size: {detection_size}, detection threshold: {detection_threshold}"
- 
- 
- def stop_running():
-     global STREAMER
-     if hasattr(STREAMER, "stop"):
-         STREAMER.stop()
-         STREAMER = None
-     return "Cancelled"
- 
- 
- def slider_changed(show_frame, video_path, frame_index):
-     if not show_frame:
-         return None, None
-     if video_path is None:
-         return None, None
-     clip = VideoFileClip(video_path)
-     frame = clip.get_frame(frame_index / clip.fps)
-     frame_array = np.array(frame)
-     clip.close()
-     return gr.Image.update(value=frame_array, visible=True), gr.Video.update(
-         visible=False
-     )
- 
- 
- def trim_and_reload(video_path, output_path, output_name, start_frame, stop_frame):
-     yield video_path, f"### \n Trimming video frame {start_frame} to {stop_frame}..."
-     try:
-         output_path = os.path.join(output_path, output_name)
-         trimmed_video = trim_video(video_path, output_path, start_frame, stop_frame)
-         yield trimmed_video, "### \n ✔️ Video trimmed and reloaded."
-     except Exception as e:
-         print(e)
-         yield video_path, "### \n ❌ Video trimming failed. See console for more info."


- ## ------------------------------ GRADIO GUI ------------------------------

- css = """
- footer{display:none !important}
- """

- with gr.Blocks(css=css) as interface:
-     gr.Markdown("# 🗿 Free Face Swapper")
-     gr.Markdown("### Help us keep this app free with a tip.")
      with gr.Row():
-         with gr.Row():
-             with gr.Column(scale=0.4):
-                 with gr.Tab("📄 Swap Condition"):
-                     swap_option = gr.Dropdown(
-                         swap_options_list,
-                         info="Choose which face or faces in the target image to swap.",
-                         multiselect=False,
-                         show_label=False,
-                         value=swap_options_list[0],
-                         interactive=True,
-                     )
-                     age = gr.Number(
-                         value=25, label="Value", interactive=True, visible=False
-                     )
- 
-                 with gr.Tab("🎚️ Detection Settings"):
-                     detect_condition_dropdown = gr.Dropdown(
-                         detect_conditions,
-                         label="Condition",
-                         value=DETECT_CONDITION,
-                         interactive=True,
-                         info="This condition is only used when multiple faces are detected on source or specific image.",
-                     )
-                     detection_size = gr.Number(
-                         label="Detection Size", value=DETECT_SIZE, interactive=True
-                     )
-                     detection_threshold = gr.Number(
-                         label="Detection Threshold",
-                         value=DETECT_THRESH,
-                         interactive=True,
-                     )
-                     apply_detection_settings = gr.Button("Apply settings")
- 
-                 with gr.Tab("📤 Output Settings"):
-                     output_directory = gr.Text(
-                         label="Output Directory",
-                         value=DEF_OUTPUT_PATH,
-                         interactive=True,
-                     )
-                     output_name = gr.Text(
-                         label="Output Name", value="Result", interactive=True
-                     )
-                     keep_output_sequence = gr.Checkbox(
-                         label="Keep output sequence", value=False, interactive=True
-                     )
- 
-                 with gr.Tab("🪄 Other Settings"):
-                     face_scale = gr.Slider(
-                         label="Face Scale",
-                         minimum=0,
-                         maximum=2,
-                         value=1,
-                         interactive=True,
-                     )
- 
-                     face_enhancer_name = gr.Dropdown(
-                         FACE_ENHANCER_LIST, label="Face Enhancer", value="NONE", multiselect=False, interactive=True
-                     )
- 
-                     with gr.Accordion("Advanced Mask", open=False):
-                         enable_face_parser_mask = gr.Checkbox(
-                             label="Enable Face Parsing",
-                             value=False,
-                             interactive=True,
-                         )
- 
-                         mask_include = gr.Dropdown(
-                             mask_regions.keys(),
-                             value=MASK_INCLUDE,
-                             multiselect=True,
-                             label="Include",
-                             interactive=True,
-                         )
-                         mask_soft_kernel = gr.Number(
-                             label="Soft Erode Kernel",
-                             value=MASK_SOFT_KERNEL,
-                             minimum=3,
-                             interactive=True,
-                             visible=False
-                         )
-                         mask_soft_iterations = gr.Number(
-                             label="Soft Erode Iterations",
-                             value=MASK_SOFT_ITERATIONS,
-                             minimum=0,
-                             interactive=True,
-                         )
- 
- 
-                     with gr.Accordion("Crop Mask", open=False):
-                         crop_top = gr.Slider(label="Top", minimum=0, maximum=511, value=0, step=1, interactive=True)
-                         crop_bott = gr.Slider(label="Bottom", minimum=0, maximum=511, value=511, step=1, interactive=True)
-                         crop_left = gr.Slider(label="Left", minimum=0, maximum=511, value=0, step=1, interactive=True)
-                         crop_right = gr.Slider(label="Right", minimum=0, maximum=511, value=511, step=1, interactive=True)
- 
- 
-                     erode_amount = gr.Slider(
-                         label="Mask Erode",
-                         minimum=0,
-                         maximum=1,
-                         value=MASK_ERODE_AMOUNT,
-                         step=0.05,
-                         interactive=True,
-                     )
- 
-                     blur_amount = gr.Slider(
-                         label="Mask Blur",
-                         minimum=0,
-                         maximum=1,
-                         value=MASK_BLUR_AMOUNT,
-                         step=0.05,
-                         interactive=True,
-                     )
- 
-                     enable_laplacian_blend = gr.Checkbox(
-                         label="Laplacian Blending",
-                         value=True,
-                         interactive=True,
-                     )
- 
- 
-                 source_image_input_male = gr.Image(label="Source Male Face", type="filepath", interactive=True)
-                 source_image_input_female = gr.Image(label="Source Female Face", type="filepath", interactive=True)
- 
-                 with gr.Group(visible=False) as specific_face:
-                     for i in range(NUM_OF_SRC_SPECIFIC):
-                         idx = i + 1
-                         code = "\n"
-                         code += f"with gr.Tab(label='({idx})'):"
-                         code += "\n\twith gr.Row():"
-                         code += f"\n\t\tsrc{idx} = gr.Image(interactive=True, type='numpy', label='Source Face {idx}')"
-                         code += f"\n\t\ttrg{idx} = gr.Image(interactive=True, type='numpy', label='Specific Face {idx}')"
-                         exec(code)
- 
-                     distance_slider = gr.Slider(
-                         minimum=0,
-                         maximum=2,
-                         value=0.6,
-                         interactive=True,
-                         label="Distance",
-                         info="Lower distance is more similar and higher distance is less similar to the target face.",
-                     )
- 
-                 with gr.Group():
-                     input_type = gr.Radio(
-                         ["Image", "Video"],
-                         label="Target Type",
-                         value="Image",
-                     )
- 
-                     with gr.Group(visible=True) as input_image_group:
-                         image_input = gr.Image(
-                             label="Target Image", interactive=True, type="filepath"
-                         )
- 
-                     with gr.Group(visible=False) as input_video_group:
-                         vid_widget = gr.Video if USE_COLAB else gr.Text
-                         video_input = gr.Video(
-                             label="Target Video", interactive=True
-                         )
-                         with gr.Accordion("✂️ Trim video", open=False):
-                             with gr.Column():
-                                 with gr.Row():
-                                     set_slider_range_btn = gr.Button(
-                                         "Set frame range", interactive=True
-                                     )
-                                     show_trim_preview_btn = gr.Checkbox(
-                                         label="Show frame when slider change",
-                                         value=True,
-                                         interactive=True,
-                                     )
- 
-                                 video_fps = gr.Number(
-                                     value=30,
-                                     interactive=False,
-                                     label="Fps",
-                                     visible=False,
-                                 )
-                                 start_frame = gr.Slider(
-                                     minimum=0,
-                                     maximum=1,
-                                     value=0,
-                                     step=1,
-                                     interactive=True,
-                                     label="Start Frame",
-                                     info="",
-                                 )
-                                 end_frame = gr.Slider(
-                                     minimum=0,
-                                     maximum=1,
-                                     value=1,
-                                     step=1,
-                                     interactive=True,
-                                     label="End Frame",
-                                     info="",
-                                 )
-                             trim_and_reload_btn = gr.Button(
-                                 "Trim and Reload", interactive=True
-                             )
- 
-                     with gr.Group(visible=False) as input_directory_group:
-                         direc_input = gr.Text(label="Path", interactive=True)
- 
-             with gr.Column(scale=0.6):
-                 info = gr.Markdown(value="...")
- 
-                 with gr.Row():
-                     swap_button = gr.Button("✨ Swap", variant="primary")
-                     cancel_button = gr.Button("⛔ Cancel")
- 
-                 preview_image = gr.Image(label="Output", interactive=False)
-                 preview_video = gr.Video(
-                     label="Output", interactive=False, visible=False
-                 )
- 
-                 with gr.Row():
-                     output_directory_button = gr.Button(
-                         "📂", interactive=False, visible=False
-                     )
-                     output_video_button = gr.Button(
-                         "🎬", interactive=False, visible=False
-                     )
- 
-                 with gr.Group():
-                     with gr.Row():
-                         gr.Markdown(
-                             "### [🤝 Enjoying? Help us keep it free with a tip 🤗](https://www.paypal.com/donate/?hosted_button_id=WUWBM97N8EENN)"
-                         )
- 
- 
-     ## ------------------------------ GRADIO EVENTS ------------------------------
- 
-     set_slider_range_event = set_slider_range_btn.click(
-         video_changed,
-         inputs=[video_input],
-         outputs=[start_frame, end_frame, video_fps],
-     )
- 
-     trim_and_reload_event = trim_and_reload_btn.click(
-         fn=trim_and_reload,
-         inputs=[video_input, output_directory, output_name, start_frame, end_frame],
-         outputs=[video_input, info],
-     )
- 
-     start_frame_event = start_frame.release(
-         fn=slider_changed,
-         inputs=[show_trim_preview_btn, video_input, start_frame],
-         outputs=[preview_image, preview_video],
-         show_progress=True,
-     )
- 
-     end_frame_event = end_frame.release(
-         fn=slider_changed,
-         inputs=[show_trim_preview_btn, video_input, end_frame],
-         outputs=[preview_image, preview_video],
-         show_progress=True,
-     )
- 
-     input_type.change(
-         update_radio,
-         inputs=[input_type],
-         outputs=[input_image_group, input_video_group, input_directory_group],
-     )
-     swap_option.change(
-         swap_option_changed,
-         inputs=[swap_option],
-         outputs=[age, specific_face, source_image_input],
-     )
- 
-     apply_detection_settings.click(
-         analyse_settings_changed,
-         inputs=[detect_condition_dropdown, detection_size, detection_threshold],
-         outputs=[info],
-     )
- 
-     src_specific_inputs = []
-     gen_variable_txt = ",".join(
-         [f"src{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
-         + [f"trg{i+1}" for i in range(NUM_OF_SRC_SPECIFIC)]
-     )
-     exec(f"src_specific_inputs = ({gen_variable_txt})")
-     swap_inputs = [
-         input_type,
-         image_input,
-         video_input,
-         direc_input,
-         source_image_input,
-         output_directory,
-         output_name,
-         keep_output_sequence,
-         swap_option,
-         age,
-         distance_slider,
-         face_enhancer_name,
-         enable_face_parser_mask,
-         mask_include,
-         mask_soft_kernel,
-         mask_soft_iterations,
-         blur_amount,
-         erode_amount,
-         face_scale,
-         enable_laplacian_blend,
-         crop_top,
-         crop_bott,
-         crop_left,
-         crop_right,
-         *src_specific_inputs,
-     ]
- 
-     swap_outputs = [
-         info,
-         preview_image,
-         output_directory_button,
-         output_video_button,
-         preview_video,
-     ]
- 
-     swap_event = swap_button.click(
-         fn=process, inputs=swap_inputs, outputs=swap_outputs, show_progress=True
-     )
- 
-     cancel_button.click(
-         fn=stop_running,
-         inputs=None,
-         outputs=[info],
-         cancels=[
-             swap_event,
-             trim_and_reload_event,
-             set_slider_range_event,
-             start_frame_event,
-             end_frame_event,
-         ],
-         show_progress=True,
-     )
-     output_directory_button.click(
-         lambda: open_directory(path=WORKSPACE), inputs=None, outputs=None
-     )
-     output_video_button.click(
-         lambda: open_directory(path=OUTPUT_FILE), inputs=None, outputs=None
      )

  if __name__ == "__main__":
-     if USE_COLAB:
-         print("Running in colab mode")
- 
-     interface.queue( max_size=20).launch(share=USE_COLAB)
 
  import os
  import cv2
  import time
  import torch
  import argparse
  import insightface
  import onnxruntime
  import numpy as np
  import gradio as gr
  from tqdm import tqdm

  from face_swapper import Inswapper, paste_to_whole
+ from face_analyser import analyse_face
+ from face_enhancer import load_face_enhancer_model, cv2_interpolations
+ from utils import create_image_grid

  ## ------------------------------ USER ARGS ------------------------------

+ parser = argparse.ArgumentParser(description="Free Face Swapper (Male/Female mode)")
+ parser.add_argument("--out_dir", default=os.getcwd())
+ parser.add_argument("--batch_size", default=32)
+ parser.add_argument("--cuda", action="store_true", default=False)
  user_args = parser.parse_args()

  USE_CUDA = user_args.cuda
  DEF_OUTPUT_PATH = user_args.out_dir
  BATCH_SIZE = int(user_args.batch_size)

+ ## ------------------------------ DEVICE ------------------------------

  PROVIDER = ["CPUExecutionProvider"]
  if USE_CUDA:
+     if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
          PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
+         print(">>> Running on CUDA")
      else:
          USE_CUDA = False
+         print(">>> CUDA not available, running on CPU")

  device = "cuda" if USE_CUDA else "cpu"
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None

  ## ------------------------------ LOAD MODELS ------------------------------

+ FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
+ FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.6)

+ FACE_SWAPPER = Inswapper(
+     model_file="./assets/pretrained_models/inswapper_128.onnx",
+     batch_size=(BATCH_SIZE if USE_CUDA else 1),
+     providers=PROVIDER,
+ )

+ ## ------------------------------ PROCESS ------------------------------

+ def swap_faces(image_path, male_source_path, female_source_path, face_enhancer_name="NONE"):
      start_time = time.time()

+     # Load target image (BGR, as read by OpenCV)
+     target = cv2.imread(image_path)

+     # Load and analyse the male/female source faces
+     analysed_source_male = analyse_face(cv2.imread(male_source_path), FACE_ANALYSER)
+     analysed_source_female = analyse_face(cv2.imread(female_source_path), FACE_ANALYSER)

+     # Detect every face in the target image
+     analysed_faces = FACE_ANALYSER.get(target)

+     preds, matrs = [], []
+     for analysed_face in tqdm(analysed_faces, desc="Swapping faces"):
+         # insightface reports gender as 1 for male, 0 for female
+         if analysed_face["gender"] == 1:  # male
+             src = analysed_source_male
+         else:  # female
+             src = analysed_source_female

+         # Inswapper exposes batch_forward, a generator yielding
+         # (predictions, transform matrices) per batch; this is the same
+         # entry point the previous version of this app used.
+         for batch_pred, batch_matr in FACE_SWAPPER.batch_forward([target], [analysed_face], [src]):
+             preds.extend(batch_pred)
+             matrs.extend(batch_matr)

      EMPTY_CACHE()

+     # Paste each swapped face back into the target frame
+     for p, m in zip(preds, matrs):
+         target = paste_to_whole(p, target, m, blend_method="laplacian")

+     # Enhance (optional)
+     if face_enhancer_name != "NONE":
+         model, runner = load_face_enhancer_model(face_enhancer_name, device=device)
+         target = runner(target, model)

+     elapsed = time.time() - start_time
+     print(f"Done in {elapsed:.2f} sec")
+     return target[:, :, ::-1]  # BGR -> RGB for display



+ ## ------------------------------ GRADIO UI ------------------------------

+ with gr.Blocks() as demo:
+     gr.Markdown("## 🧑➡👩 Face Swapper (Male+Female sources)")

      with gr.Row():
+         with gr.Column():
+             image_input = gr.Image(label="Target Image", type="filepath")
+             male_input = gr.Image(label="Source Male", type="filepath")
+             female_input = gr.Image(label="Source Female", type="filepath")
+             enhancer = gr.Dropdown(
+                 ["NONE"] + cv2_interpolations, label="Face Enhancer", value="NONE"
+             )
+             run_btn = gr.Button("✨ Swap")

+         with gr.Column():
+             output_image = gr.Image(label="Output")

+     run_btn.click(
+         fn=swap_faces,
+         inputs=[image_input, male_input, female_input, enhancer],
+         outputs=output_image,
      )

  if __name__ == "__main__":
+     demo.launch()
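
For a quick sanity check outside the UI, the new swap_faces function can also be called directly. A minimal sketch, assuming the pretrained models are present under ./assets/pretrained_models; the file names below are hypothetical, for illustration only:

    import cv2

    result_rgb = swap_faces(
        "group_photo.jpg",          # hypothetical target image with one or more faces
        "male_source.jpg",          # face applied to faces detected as male
        "female_source.jpg",        # face applied to faces detected as female
        face_enhancer_name="NONE",  # or an entry from cv2_interpolations
    )
    # swap_faces returns RGB for Gradio display; flip back to BGR before saving with OpenCV.
    cv2.imwrite("result.png", result_rgb[:, :, ::-1])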