mr2along commited on
Commit
3f532f2
·
verified ·
1 Parent(s): 8e3bca8

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +67 -55
app.py CHANGED
@@ -5,18 +5,18 @@ import torch
5
  import argparse
6
  import insightface
7
  import onnxruntime
8
- import numpy as np
9
  import gradio as gr
10
  from tqdm import tqdm
11
 
12
  from face_swapper import Inswapper, paste_to_whole
13
  from face_analyser import analyse_face
14
- from face_enhancer import load_face_enhancer_model, cv2_interpolations
15
- from utils import create_image_grid
16
-
17
- ## ------------------------------ USER ARGS ------------------------------
18
 
19
- parser = argparse.ArgumentParser(description="Free Face Swapper (Male/Female mode)")
 
20
  parser.add_argument("--out_dir", default=os.getcwd())
21
  parser.add_argument("--batch_size", default=32)
22
  parser.add_argument("--cuda", action="store_true", default=False)
@@ -26,22 +26,19 @@ USE_CUDA = user_args.cuda
26
  DEF_OUTPUT_PATH = user_args.out_dir
27
  BATCH_SIZE = int(user_args.batch_size)
28
 
29
- ## ------------------------------ DEVICE ------------------------------
30
-
31
  PROVIDER = ["CPUExecutionProvider"]
32
- if USE_CUDA:
33
- if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
34
- PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
35
- print(">>> Running on CUDA")
36
- else:
37
- USE_CUDA = False
38
- print(">>> CUDA not available, running on CPU")
39
 
40
  device = "cuda" if USE_CUDA else "cpu"
41
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
42
 
43
- ## ------------------------------ LOAD MODELS ------------------------------
44
-
45
  FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
46
  FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.6)
47
 
@@ -51,64 +48,79 @@ FACE_SWAPPER = Inswapper(
51
  providers=PROVIDER,
52
  )
53
 
54
- ## ------------------------------ PROCESS ------------------------------
 
 
55
 
56
- def swap_faces(image_path, male_source_path, female_source_path, face_enhancer_name="NONE"):
 
57
  start_time = time.time()
58
 
59
- # Load target
60
- target = cv2.imread(image_path)
 
 
 
 
 
 
 
61
 
62
- # Load source male/female
63
- analysed_source_male = analyse_face(cv2.imread(male_source_path), FACE_ANALYSER)
64
- analysed_source_female = analyse_face(cv2.imread(female_source_path), FACE_ANALYSER)
65
 
66
- # Analyse target
67
- analysed_faces = FACE_ANALYSER.get(target)
68
 
69
  preds, matrs = [], []
70
  for analysed_face in tqdm(analysed_faces, desc="Swapping faces"):
71
- if analysed_face["gender"] == 1: # male
72
- src = analysed_source_male
73
- else: # female
74
- src = analysed_source_female
75
-
76
- batch_pred, batch_matr = FACE_SWAPPER.get([target], [analysed_face], [src])
77
  preds.extend(batch_pred)
78
  matrs.extend(batch_matr)
79
  EMPTY_CACHE()
80
 
81
- # Paste back
82
  for p, m in zip(preds, matrs):
83
- target = paste_to_whole(p, target, m, blend_method="laplacian")
84
-
85
- # Enhance (optional)
86
- if face_enhancer_name != "NONE":
87
- model, runner = load_face_enhancer_model(face_enhancer_name, device=device)
88
- target = runner(target, model)
89
-
90
- elapsed = time.time() - start_time
91
- print(f" Done in {elapsed:.2f} sec")
92
- return target[:, :, ::-1] # BGR->RGB for display
93
-
94
-
95
- ## ------------------------------ GRADIO UI ------------------------------
96
-
 
 
 
 
 
 
 
 
 
 
 
97
  with gr.Blocks() as demo:
98
- gr.Markdown("## 🧑➡👩 Face Swapper (Male+Female sources)")
99
 
100
  with gr.Row():
101
  with gr.Column():
102
- image_input = gr.Image(label="Target Image", type="filepath")
103
- male_input = gr.Image(label="Source Male", type="filepath")
104
- female_input = gr.Image(label="Source Female", type="filepath")
105
- enhancer = gr.Dropdown(
106
- ["NONE"] + cv2_interpolations, label="Face Enhancer", value="NONE"
107
- )
108
  run_btn = gr.Button("✨ Swap")
109
 
110
  with gr.Column():
111
- output_image = gr.Image(label="Output")
112
 
113
  run_btn.click(
114
  fn=swap_faces,
 
5
  import argparse
6
  import insightface
7
  import onnxruntime
 
8
  import gradio as gr
9
  from tqdm import tqdm
10
 
11
  from face_swapper import Inswapper, paste_to_whole
12
  from face_analyser import analyse_face
13
+ from face_enhancer import (
14
+ load_face_enhancer_model,
15
+ get_available_enhancer_names,
16
+ )
17
 
18
+ # ------------------------------ ARGS ------------------------------
19
+ parser = argparse.ArgumentParser(description="Face Swapper (Male+Female with Enhancers)")
20
  parser.add_argument("--out_dir", default=os.getcwd())
21
  parser.add_argument("--batch_size", default=32)
22
  parser.add_argument("--cuda", action="store_true", default=False)
 
26
  DEF_OUTPUT_PATH = user_args.out_dir
27
  BATCH_SIZE = int(user_args.batch_size)
28
 
29
+ # ------------------------------ DEVICE ------------------------------
 
30
  PROVIDER = ["CPUExecutionProvider"]
31
+ if USE_CUDA and "CUDAExecutionProvider" in onnxruntime.get_available_providers():
32
+ PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
33
+ print(">>> Running on CUDA")
34
+ else:
35
+ USE_CUDA = False
36
+ print(">>> Running on CPU")
 
37
 
38
  device = "cuda" if USE_CUDA else "cpu"
39
  EMPTY_CACHE = lambda: torch.cuda.empty_cache() if device == "cuda" else None
40
 
41
+ # ------------------------------ MODELS ------------------------------
 
42
  FACE_ANALYSER = insightface.app.FaceAnalysis(name="buffalo_l", providers=PROVIDER)
43
  FACE_ANALYSER.prepare(ctx_id=0, det_size=(640, 640), det_thresh=0.6)
44
 
 
48
  providers=PROVIDER,
49
  )
50
 
51
+ # ------------------------------ ENHANCERS ------------------------------
52
+ ENHANCER_CHOICES = ["NONE"] + get_available_enhancer_names()
53
+ # ví dụ: ["NONE", "CodeFormer", "GFPGAN", "REAL-ESRGAN 2x", "REAL-ESRGAN 4x", "REAL-ESRGAN 8x"]
54
 
55
+ # ------------------------------ PROCESS ------------------------------
56
+ def swap_faces(image_path, male_source_path, female_source_path, enhancer_name="NONE"):
57
  start_time = time.time()
58
 
59
+ # Load target & sources
60
+ target_bgr = cv2.imread(image_path)
61
+ if target_bgr is None:
62
+ raise ValueError("Không đọc được ảnh đích.")
63
+
64
+ src_male_img = cv2.imread(male_source_path)
65
+ src_female_img = cv2.imread(female_source_path)
66
+ if src_male_img is None or src_female_img is None:
67
+ raise ValueError("Không đọc được ảnh nguồn (nam/nữ).")
68
 
69
+ analysed_source_male = analyse_face(src_male_img, FACE_ANALYSER)
70
+ analysed_source_female = analyse_face(src_female_img, FACE_ANALYSER)
 
71
 
72
+ # Phân tích các khuôn mặt trong ảnh đích
73
+ analysed_faces = FACE_ANALYSER.get(target_bgr)
74
 
75
  preds, matrs = [], []
76
  for analysed_face in tqdm(analysed_faces, desc="Swapping faces"):
77
+ # gender: 1 = male, 0 = female (theo insightface)
78
+ src_face = analysed_source_male if analysed_face.get("gender", 1) == 1 else analysed_source_female
79
+ batch_pred, batch_matr = FACE_SWAPPER.get([target_bgr], [analysed_face], [src_face])
 
 
 
80
  preds.extend(batch_pred)
81
  matrs.extend(batch_matr)
82
  EMPTY_CACHE()
83
 
84
+ # Ghép lại
85
  for p, m in zip(preds, matrs):
86
+ target_bgr = paste_to_whole(
87
+ foreground=p,
88
+ background=target_bgr,
89
+ matrix=m,
90
+ mask=None,
91
+ crop_mask=(0, 0, 0, 0),
92
+ blur_amount=0.1,
93
+ erode_amount=0.15,
94
+ blend_method="laplacian" # tự nhiên hơn
95
+ )
96
+
97
+ # Enhance (nếu chọn)
98
+ if enhancer_name != "NONE":
99
+ try:
100
+ model, runner = load_face_enhancer_model(name=enhancer_name, device=device)
101
+ target_bgr = runner(target_bgr, model)
102
+ except AssertionError as e:
103
+ print(f"[Enhancer] {e}. Trả về ảnh không enhance.")
104
+ except Exception as e:
105
+ print(f"[Enhancer] Lỗi khi chạy {enhancer_name}: {e}. Trả về ảnh không enhance.")
106
+
107
+ print(f"✔ Hoàn tất trong {time.time() - start_time:.2f}s")
108
+ return target_bgr[:, :, ::-1] # BGR -> RGB để hiển thị
109
+
110
+ # ------------------------------ UI ------------------------------
111
  with gr.Blocks() as demo:
112
+ gr.Markdown("## 🧑‍🦱➡👩 Face Swapper (2 nguồn nam/nữ) + Enhancer (CodeFormer / GFPGAN / Real-ESRGAN)")
113
 
114
  with gr.Row():
115
  with gr.Column():
116
+ image_input = gr.Image(label="Ảnh đích (Target Image)", type="filepath")
117
+ male_input = gr.Image(label="Ảnh nguồn Nam", type="filepath")
118
+ female_input = gr.Image(label="Ảnh nguồn Nữ", type="filepath")
119
+ enhancer = gr.Dropdown(ENHANCER_CHOICES, label="Face Enhancer", value="NONE")
 
 
120
  run_btn = gr.Button("✨ Swap")
121
 
122
  with gr.Column():
123
+ output_image = gr.Image(label="Kết quả")
124
 
125
  run_btn.click(
126
  fn=swap_faces,