# Sffake / app.py
# Page header captured with the file (not part of the program):
# mr2along's picture · "Update app.py" · commit 8e3bca8 (verified) ·
# raw / history / blame · 4.01 kB
import os
import cv2
import time
import torch
import argparse
import insightface
import onnxruntime
import numpy as np
import gradio as gr
from tqdm import tqdm
from face_swapper import Inswapper, paste_to_whole
from face_analyser import analyse_face
from face_enhancer import load_face_enhancer_model, cv2_interpolations
from utils import create_image_grid
## ------------------------------ USER ARGS ------------------------------
# Command-line options are parsed once, at import time, and exposed as
# module-level constants.
parser = argparse.ArgumentParser(description="Free Face Swapper (Male/Female mode)")
parser.add_argument(
    "--out_dir",
    default=os.getcwd(),
    help="default output directory (defaults to the current working directory)",
)
parser.add_argument(
    "--batch_size",
    # Let argparse do the conversion so a bad value produces a usage error
    # instead of a bare ValueError traceback later on.
    type=int,
    default=32,
    help="batch size as an integer (default: 32)",
)
parser.add_argument(
    "--cuda",
    action="store_true",
    default=False,
    help="request CUDA execution",
)
user_args = parser.parse_args()

USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = user_args.batch_size  # already an int thanks to type=int
## ------------------------------ DEVICE ------------------------------
# Start from CPU-only execution; CUDA is added only when it was requested
# and onnxruntime lists the CUDA provider as available.
PROVIDER = ["CPUExecutionProvider"]
if USE_CUDA and "CUDAExecutionProvider" not in onnxruntime.get_available_providers():
    # Requested but not available: fall back to CPU.
    USE_CUDA = False
    print(">>> CUDA not available, running on CPU")
elif USE_CUDA:
    PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
    print(">>> Running on CUDA")

device = "cuda" if USE_CUDA else "cpu"


def EMPTY_CACHE():
    """Call ``torch.cuda.empty_cache()`` when ``device`` is "cuda"; otherwise do nothing."""
    if device == "cuda":
        return torch.cuda.empty_cache()
    return None
## ------------------------------ LOAD MODELS ------------------------------
# Face analyser: insightface "buffalo_l" model pack on the selected providers.
FACE_ANALYSER = insightface.app.FaceAnalysis(providers=PROVIDER, name="buffalo_l")
FACE_ANALYSER.prepare(det_thresh=0.6, det_size=(640, 640), ctx_id=0)

# Face swapper: the configured batch size is used only with CUDA; otherwise
# the batch size is fixed at 1.
FACE_SWAPPER = Inswapper(
    providers=PROVIDER,
    batch_size=BATCH_SIZE if USE_CUDA else 1,
    model_file="./assets/pretrained_models/inswapper_128.onnx",
)
## ------------------------------ PROCESS ------------------------------
def _read_bgr_image(path, label):
    """Load the image at *path* with OpenCV, raising a clear error if it is unusable.

    Args:
        path: Filesystem path of the image; may be ``None`` or empty when the
            caller did not supply one.
        label: Human-readable name of the input, used in error messages.

    Returns:
        The array returned by ``cv2.imread``.

    Raises:
        ValueError: If *path* is empty/``None`` or the file cannot be read.
    """
    if not path:
        raise ValueError(f"{label} image is missing.")
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports failure by returning None rather than raising.
        raise ValueError(f"{label} image could not be read: {path}")
    return image


def swap_faces(image_path, male_source_path, female_source_path, face_enhancer_name="NONE"):
    """Swap every face detected in the target image with a gender-matched source face.

    Args:
        image_path: Path of the target image whose faces are replaced.
        male_source_path: Path of the source image used for faces classified
            as male (``gender == 1``).
        female_source_path: Path of the source image used for all other faces.
        face_enhancer_name: Name passed to ``load_face_enhancer_model``; the
            default ``"NONE"`` skips the enhancement step.

    Returns:
        The processed image with its channel axis reversed (BGR -> RGB) for
        display.

    Raises:
        ValueError: If any of the three image paths is missing or unreadable.
    """
    start_time = time.time()

    # Validate and load all inputs up front so a missing upload fails with a
    # clear message instead of an opaque error inside the analyser.
    target = _read_bgr_image(image_path, "Target")
    male_source = _read_bgr_image(male_source_path, "Source Male")
    female_source = _read_bgr_image(female_source_path, "Source Female")

    analysed_source_male = analyse_face(male_source, FACE_ANALYSER)
    analysed_source_female = analyse_face(female_source, FACE_ANALYSER)

    # Detect the faces to replace in the target.
    analysed_faces = FACE_ANALYSER.get(target)

    preds, matrs = [], []
    for analysed_face in tqdm(analysed_faces, desc="Swapping faces"):
        # gender == 1 is treated as male; anything else uses the female source.
        if analysed_face["gender"] == 1:
            src = analysed_source_male
        else:
            src = analysed_source_female
        batch_pred, batch_matr = FACE_SWAPPER.get([target], [analysed_face], [src])
        preds.extend(batch_pred)
        matrs.extend(batch_matr)
        EMPTY_CACHE()

    # Paste each swapped face back; every paste builds on the previous result.
    for p, m in zip(preds, matrs):
        target = paste_to_whole(p, target, m, blend_method="laplacian")

    # Optional enhancement pass over the whole image.
    if face_enhancer_name != "NONE":
        model, runner = load_face_enhancer_model(face_enhancer_name, device=device)
        target = runner(target, model)

    elapsed = time.time() - start_time
    print(f"✔ Done in {elapsed:.2f} sec")
    return target[:, :, ::-1]  # BGR->RGB for display
## ------------------------------ GRADIO UI ------------------------------
# One row with two columns: the first holds the three image inputs, the
# enhancer dropdown and the trigger button; the second holds the output image.
with gr.Blocks() as demo:
    gr.Markdown("## 🧑➡👩 Face Swapper (Male+Female sources)")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(type="filepath", label="Target Image")
            male_input = gr.Image(type="filepath", label="Source Male")
            female_input = gr.Image(type="filepath", label="Source Female")
            enhancer = gr.Dropdown(
                choices=["NONE"] + cv2_interpolations,
                value="NONE",
                label="Face Enhancer",
            )
            run_btn = gr.Button("✨ Swap")

        with gr.Column():
            output_image = gr.Image(label="Output")

    # The order of `inputs` matches swap_faces' positional parameters.
    run_btn.click(
        swap_faces,
        inputs=[image_input, male_input, female_input, enhancer],
        outputs=output_image,
    )

# Start the web UI only when the file is executed as a script.
if __name__ == "__main__":
    demo.launch()