Sffake / app.py
mr2along's picture
Update app.py
9c9d565 verified
raw
history blame
6.87 kB
import os
import cv2
import time
import torch
import argparse
import insightface
import onnxruntime
import gradio as gr
from tqdm import tqdm
from moviepy.editor import VideoFileClip
from face_swapper import Inswapper, paste_to_whole
from face_analyser import analyse_face
from face_enhancer import load_face_enhancer_model, get_available_enhancer_names
from utils import merge_img_sequence_from_ref
# ------------------------------ ARGS ------------------------------
# Command-line options are parsed once at import time; the resulting
# module-level constants configure every model and output path below.
parser = argparse.ArgumentParser(description="Face Swapper (Multi-target, Male+Female Sources)")
parser.add_argument(
    "--out_dir",
    default=os.getcwd(),
    help="Directory for swapped videos and temporary frames (default: current directory).",
)
parser.add_argument(
    "--batch_size",
    type=int,  # let argparse reject non-integers with a usage error instead of a traceback
    default=32,
    help="Swapper batch size when running on CUDA; CPU runs always use 1 (default: 32).",
)
parser.add_argument(
    "--cuda",
    action="store_true",
    default=False,
    help="Use the CUDA execution provider when onnxruntime reports it as available.",
)
user_args = parser.parse_args()
if user_args.batch_size < 1:
    parser.error("--batch_size must be a positive integer")

USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = user_args.batch_size
# ------------------------------ DEVICE ------------------------------
# CUDA is used only when it was requested on the command line AND
# onnxruntime lists the CUDA provider; in every other case fall back to CPU
# and clear USE_CUDA so later code sees the effective setting.
if not USE_CUDA or "CUDAExecutionProvider" not in onnxruntime.get_available_providers():
    USE_CUDA = False
    PROVIDER = ["CPUExecutionProvider"]
    print(">>> Running on CPU")
else:
    PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    print(">>> Running on CUDA")

# Device string handed to the face enhancer loader.
device = "cuda" if USE_CUDA else "cpu"


def EMPTY_CACHE():
    """Release cached CUDA memory held by torch; do nothing on CPU."""
    if device == "cuda":
        return torch.cuda.empty_cache()
    return None
# ------------------------------ MODELS ------------------------------
# Shared face analyser (insightface model set named "buffalo_l"), created once
# at import time and reused for every source image and target frame.
FACE_ANALYSER = insightface.app.FaceAnalysis(providers=PROVIDER, name="buffalo_l")
FACE_ANALYSER.prepare(det_thresh=0.6, det_size=(640, 640), ctx_id=0)

# Shared swapper; the configured batch size is only applied on CUDA,
# CPU runs are limited to a batch of one.
FACE_SWAPPER = Inswapper(
    providers=PROVIDER,
    batch_size=BATCH_SIZE if USE_CUDA else 1,
    model_file="./assets/pretrained_models/inswapper_128.onnx",
)
# ------------------------------ ENHANCERS ------------------------------
# "NONE" is the sentinel that disables enhancement; it is listed first so the
# UI dropdown can default to it.
ENHANCER_CHOICES = ["NONE", *get_available_enhancer_names()]
# ------------------------------ CORE SWAP FUNC ------------------------------
# Single-slot cache for the most recently used enhancer, keyed by its name.
# Only one entry is kept so switching enhancers does not accumulate models.
_ENHANCER_CACHE = {}


def _get_enhancer(enhancer_name):
    """Return the ``(model, runner)`` pair for *enhancer_name*, loading it only on a cache miss."""
    if enhancer_name not in _ENHANCER_CACHE:
        loaded = load_face_enhancer_model(name=enhancer_name, device=device)
        _ENHANCER_CACHE.clear()
        _ENHANCER_CACHE[enhancer_name] = loaded
    return _ENHANCER_CACHE[enhancer_name]


def swap_on_frame(frame_bgr, analysed_source_male, analysed_source_female, enhancer_name="NONE"):
    """Swap every detected face in one frame and optionally enhance the result.

    Args:
        frame_bgr: Target image as produced by ``cv2.imread`` /
            ``cv2.VideoCapture.read`` (callers pass BGR frames).
        analysed_source_male: Analysed source face used for faces whose
            ``"gender"`` entry equals 1 (or is missing).
        analysed_source_female: Analysed source face used for all other faces.
        enhancer_name: Name of a face enhancer, or ``"NONE"`` to skip
            enhancement.

    Returns:
        The frame with all swapped faces pasted back, enhanced when an
        enhancer was requested and ran successfully.
    """
    analysed_faces = FACE_ANALYSER.get(frame_bgr)

    preds, matrs = [], []
    for analysed_face in analysed_faces:
        # gender == 1 selects the male source; a face without a "gender"
        # entry is treated as male by the default value.
        src_face = analysed_source_male if analysed_face.get("gender", 1) == 1 else analysed_source_female
        batch_pred, batch_matr = FACE_SWAPPER.get([frame_bgr], [analysed_face], [src_face])
        preds.extend(batch_pred)
        matrs.extend(batch_matr)
    EMPTY_CACHE()

    # Paste each swapped face back using its matching matrix; every paste
    # builds on the result of the previous one.
    for p, m in zip(preds, matrs):
        frame_bgr = paste_to_whole(
            foreground=p,
            background=frame_bgr,
            matrix=m,
            mask=None,
            crop_mask=(0, 0, 0, 0),
            blur_amount=0.1,
            erode_amount=0.15,
            blend_method="laplacian",
        )

    if enhancer_name != "NONE":
        # Enhancement is best-effort: on failure the un-enhanced frame is
        # returned and the error is only logged. The model comes from the
        # cache so a video does not reload it for every frame.
        try:
            model, runner = _get_enhancer(enhancer_name)
            frame_bgr = runner(frame_bgr, model)
        except Exception as e:
            print(f"[Enhancer] Lỗi khi chạy {enhancer_name}: {e}")
    return frame_bgr
# ------------------------------ PROCESS ------------------------------
# File extensions routed to the image and video branches of swap_faces.
_IMAGE_EXTS = (".jpg", ".jpeg", ".png")
_VIDEO_EXTS = (".mp4", ".avi", ".mov")


def _upload_path(upload):
    """Return the filesystem path of an upload given as a path or as an object with ``.name``."""
    if isinstance(upload, (str, os.PathLike)):
        return os.fspath(upload)
    return upload.name


def _load_source_face(upload, label):
    """Read one source image and return its analysed face, raising ``gr.Error`` on bad input."""
    if upload is None:
        raise gr.Error(f"Missing {label} source image.")
    path = _upload_path(upload)
    image = cv2.imread(path)
    if image is None:  # cv2.imread signals an unreadable file by returning None
        raise gr.Error(f"Cannot read {label} source image: {os.path.basename(path)}")
    face = analyse_face(image, FACE_ANALYSER)
    # NOTE(review): assumes analyse_face may return None when no face is
    # found - confirm against face_analyser; the check is harmless otherwise.
    if face is None:
        raise gr.Error(f"No face found in {label} source image: {os.path.basename(path)}")
    return face


def _swap_video(target_path, analysed_source_male, analysed_source_female, enhancer_name):
    """Swap faces in every frame of one video; return the output path, or None if nothing was read."""
    import shutil
    import tempfile

    cap = cv2.VideoCapture(target_path)
    if not cap.isOpened():
        cap.release()
        print(f"[Skip] Cannot open video: {os.path.basename(target_path)}")
        return None

    os.makedirs(DEF_OUTPUT_PATH, exist_ok=True)
    # A unique directory per video keeps runs from overwriting each other's
    # frames; it is removed again once the video has been assembled.
    temp_dir = tempfile.mkdtemp(prefix="temp_frames_", dir=DEF_OUTPUT_PATH)
    try:
        frame_paths = []
        try:
            fps = cap.get(cv2.CAP_PROP_FPS)
            frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
            print(f">>> Đang xử lý video {os.path.basename(target_path)}: {frame_count} frames @ {fps:.1f} FPS")
            while True:
                ret, frame = cap.read()
                if not ret:
                    break
                swapped = swap_on_frame(frame, analysed_source_male, analysed_source_female, enhancer_name)
                frame_path = os.path.join(temp_dir, f"frame_{len(frame_paths):05d}.jpg")
                if not cv2.imwrite(frame_path, swapped):
                    raise OSError(f"Failed to write frame: {frame_path}")
                frame_paths.append(frame_path)
        finally:
            # Release the capture even when a frame fails to process.
            cap.release()

        if not frame_paths:
            print(f"[Skip] No frames decoded from: {os.path.basename(target_path)}")
            return None

        out_path = os.path.join(DEF_OUTPUT_PATH, f"swapped_{os.path.basename(target_path)}")
        # NOTE(review): assumes the merge finishes writing out_path before it
        # returns, so the frames can be deleted right after - confirm in utils.
        merge_img_sequence_from_ref(target_path, frame_paths, out_path)
        return out_path
    finally:
        shutil.rmtree(temp_dir, ignore_errors=True)


def swap_faces(target_files, male_file, female_file, enhancer_name="NONE"):
    """Swap faces in every uploaded target image or video.

    Args:
        target_files: Uploaded targets; a list (or a single item) of path
            strings or objects exposing the path as ``.name``. ``None`` is
            treated as no targets.
        male_file: Upload holding the source face applied to male faces.
        female_file: Upload holding the source face applied to female faces.
        enhancer_name: Face enhancer name, or ``"NONE"`` to disable it.

    Returns:
        ``(out_images, out_videos)``: two parallel lists with one entry per
        processed target. An image target contributes an RGB array and
        ``None``; a video target contributes ``None`` and the output path.
        Unreadable or unsupported targets are logged and skipped.

    Raises:
        gr.Error: If a source image is missing, unreadable, or has no face.
    """
    start_time = time.time()
    analysed_source_male = _load_source_face(male_file, "male")
    analysed_source_female = _load_source_face(female_file, "female")

    if target_files is None:
        target_files = []
    elif not isinstance(target_files, (list, tuple)):
        # A single upload arrives as one object rather than a list.
        target_files = [target_files]

    out_images, out_videos = [], []
    for f in target_files:
        target_path = _upload_path(f)
        ext = os.path.splitext(target_path)[-1].lower()
        # -------------------- IMAGE --------------------
        if ext in _IMAGE_EXTS:
            frame_bgr = cv2.imread(target_path)
            if frame_bgr is None:
                print(f"[Skip] Cannot read image: {os.path.basename(target_path)}")
                continue
            out_frame = swap_on_frame(frame_bgr, analysed_source_male, analysed_source_female, enhancer_name)
            out_images.append(out_frame[:, :, ::-1])  # BGR -> RGB
            out_videos.append(None)
        # -------------------- VIDEO --------------------
        elif ext in _VIDEO_EXTS:
            out_path = _swap_video(target_path, analysed_source_male, analysed_source_female, enhancer_name)
            if out_path is None:
                continue
            out_images.append(None)
            out_videos.append(out_path)
        else:
            print(f"[Skip] Unsupported file type: {os.path.basename(target_path)}")

    print(f"✔ Hoàn tất tất cả trong {time.time() - start_time:.2f}s")
    return out_images, out_videos
# ------------------------------ UI ------------------------------
# Two-column layout: uploads and options on the left, results on the right.
with gr.Blocks() as demo:
    gr.Markdown("## 🧑‍🦱➡👩 Face Swapper (Upload nhiều file target + nguồn nam/nữ) + Enhancer")
    with gr.Row():
        with gr.Column():
            # file_count="multiple" is what enables multi-file upload; the
            # accepted extensions mirror the ones swap_faces can process.
            target_input = gr.File(
                label="Files đích (ảnh/video)",
                file_types=[".jpg", ".jpeg", ".png", ".mp4", ".avi", ".mov"],
                file_count="multiple",
            )
            male_input = gr.File(label="File nguồn Nam (ảnh)", file_types=[".jpg", ".jpeg", ".png"])
            female_input = gr.File(label="File nguồn Nữ (ảnh)", file_types=[".jpg", ".jpeg", ".png"])
            enhancer = gr.Dropdown(ENHANCER_CHOICES, label="Face Enhancer", value="NONE")
            run_btn = gr.Button("✨ Swap")
        with gr.Column():
            output_gallery = gr.Gallery(label="Kết quả ảnh")
            output_videos = gr.File(label="Kết quả video", file_types=[".mp4"], file_count="multiple")

    def run_wrapper(target_files, male_file, female_file, enhancer_name):
        """Run the swap and split its parallel result lists into gallery images and video files."""
        out_imgs, out_vids = swap_faces(target_files, male_file, female_file, enhancer_name)
        # Each list holds None for targets of the other kind; drop those
        # placeholders before handing the results to the output components.
        imgs = [img for img in out_imgs if img is not None]
        vids = [v for v in out_vids if v is not None]
        return imgs, vids

    run_btn.click(
        fn=run_wrapper,
        inputs=[target_input, male_input, female_input, enhancer],
        outputs=[output_gallery, output_videos],
    )
# Launch the Gradio app only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()