File size: 4,014 Bytes
c2a191e
 
 
 
 
 
 
 
 
 
 
ba9144c
8e3bca8
 
 
c2a191e
 
 
8e3bca8
 
 
 
c2a191e
 
 
 
 
 
8e3bca8
c2a191e
 
 
8e3bca8
c2a191e
8e3bca8
c2a191e
 
8e3bca8
c2a191e
ba9144c
 
 
c2a191e
ba9144c
8e3bca8
 
ba9144c
8e3bca8
 
 
 
 
ba9144c
8e3bca8
ba9144c
8e3bca8
5269076
c2a191e
8e3bca8
 
c2a191e
8e3bca8
 
 
ba9144c
8e3bca8
 
ba9144c
8e3bca8
 
 
 
 
 
c2a191e
8e3bca8
 
 
c2a191e
 
8e3bca8
 
 
ba9144c
8e3bca8
 
 
 
ba9144c
8e3bca8
 
 
ba9144c
 
8e3bca8
fdd0003
8e3bca8
 
fdd0003
ba9144c
8e3bca8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c2a191e
ba9144c
 
8e3bca8
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
import cv2
import time
import torch
import argparse
import insightface
import onnxruntime
import numpy as np
import gradio as gr
from tqdm import tqdm

from face_swapper import Inswapper, paste_to_whole
from face_analyser import analyse_face
from face_enhancer import load_face_enhancer_model, cv2_interpolations
from utils import create_image_grid

## ------------------------------ USER ARGS ------------------------------

parser = argparse.ArgumentParser(description="Free Face Swapper (Male/Female mode)")
parser.add_argument(
    "--out_dir",
    default=os.getcwd(),
    help="directory for outputs (defaults to the current working directory)",
)
# type=int lets argparse reject non-numeric values with a usage message
# instead of a later ValueError traceback.
parser.add_argument(
    "--batch_size",
    type=int,
    default=32,
    help="face-swap batch size (only applied when running on CUDA)",
)
parser.add_argument(
    "--cuda",
    action="store_true",
    default=False,
    help="use the ONNX Runtime CUDA execution provider if it is available",
)
user_args = parser.parse_args()
if user_args.batch_size < 1:
    parser.error("--batch_size must be a positive integer")

USE_CUDA = user_args.cuda
DEF_OUTPUT_PATH = user_args.out_dir
BATCH_SIZE = user_args.batch_size

## ------------------------------ DEVICE ------------------------------

PROVIDER = ["CPUExecutionProvider"]
if USE_CUDA:
    if "CUDAExecutionProvider" in onnxruntime.get_available_providers():
        PROVIDER = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        print(">>> Running on CUDA")
    else:
        USE_CUDA = False
        print(">>> CUDA not available, running on CPU")

# PyTorch device (used for the face enhancer). ONNX Runtime exposing CUDA does
# not imply the installed torch build supports CUDA, so check torch separately.
if USE_CUDA and not torch.cuda.is_available():
    print(">>> torch has no CUDA support, torch models run on CPU")
device = "cuda" if USE_CUDA and torch.cuda.is_available() else "cpu"


def EMPTY_CACHE():
    """Release PyTorch's cached CUDA memory; a no-op when running on CPU."""
    if device == "cuda":
        torch.cuda.empty_cache()


## ------------------------------ LOAD MODELS ------------------------------

# Shared face analyser: its results supply the faces (and the `gender` field)
# that drive source selection when swapping.
FACE_ANALYSER = insightface.app.FaceAnalysis(providers=PROVIDER, name="buffalo_l")
FACE_ANALYSER.prepare(det_thresh=0.6, det_size=(640, 640), ctx_id=0)

# The configured batch size is only honoured on CUDA; CPU runs use 1.
FACE_SWAPPER = Inswapper(
    providers=PROVIDER,
    batch_size=BATCH_SIZE if USE_CUDA else 1,
    model_file="./assets/pretrained_models/inswapper_128.onnx",
)

## ------------------------------ PROCESS ------------------------------

def _read_image(path, label):
    """Read an image with OpenCV, raising a user-facing Gradio error on failure.

    Args:
        path: Filesystem path from a ``gr.Image(type="filepath")`` input;
            presumably ``None`` when nothing was uploaded — TODO confirm.
        label: Human-readable name of the input, used in the error message.

    Returns:
        The decoded image as returned by ``cv2.imread`` (BGR channel order).

    Raises:
        gr.Error: if no path was given or OpenCV could not decode the file.
    """
    if not path:
        raise gr.Error(f"Please upload the {label} image.")
    image = cv2.imread(path)
    # cv2.imread signals failure by returning None instead of raising.
    if image is None:
        raise gr.Error(f"Could not read the {label} image: {path}")
    return image


def swap_faces(image_path, male_source_path, female_source_path, face_enhancer_name="NONE"):
    """Replace every face in the target image with a gender-matched source face.

    A target face whose analysed ``gender`` equals 1 receives the male source
    face; every other face receives the female source face. Each source image
    is loaded and analysed only when a face of that gender is present, so a
    target containing a single gender needs only the matching source image.

    Args:
        image_path: Path to the target image.
        male_source_path: Path to the image providing the male source face.
        female_source_path: Path to the image providing the female source face.
        face_enhancer_name: Name passed to ``load_face_enhancer_model``;
            ``"NONE"`` skips enhancement.

    Returns:
        The result image with its channel axis reversed (BGR -> RGB) for display.

    Raises:
        gr.Error: if a required image is missing or cannot be read.
    """
    start_time = time.perf_counter()

    target = _read_image(image_path, "target")
    analysed_faces = FACE_ANALYSER.get(target)

    source_paths = {"male": male_source_path, "female": female_source_path}
    analysed_sources = {}  # gender -> analysed source face, filled on first use

    preds, matrs = [], []
    for analysed_face in tqdm(analysed_faces, desc="Swapping faces"):
        gender = "male" if analysed_face["gender"] == 1 else "female"
        if gender not in analysed_sources:
            source_image = _read_image(source_paths[gender], f"source {gender}")
            analysed_sources[gender] = analyse_face(source_image, FACE_ANALYSER)

        batch_pred, batch_matr = FACE_SWAPPER.get(
            [target], [analysed_face], [analysed_sources[gender]]
        )
        preds.extend(batch_pred)
        matrs.extend(batch_matr)
        EMPTY_CACHE()

    # Crops are pasted only after the loop, so every swap above reads the
    # original, unmodified target.
    for pred, matr in zip(preds, matrs):
        target = paste_to_whole(pred, target, matr, blend_method="laplacian")

    if face_enhancer_name != "NONE":
        # NOTE(review): the enhancer runs on the whole frame, not per face crop;
        # confirm the runner accepts arbitrarily sized images without resizing them.
        model, runner = load_face_enhancer_model(face_enhancer_name, device=device)
        target = runner(target, model)

    elapsed = time.perf_counter() - start_time
    print(f"✔ Done in {elapsed:.2f} sec")
    return target[:, :, ::-1]  # BGR->RGB for display


## ------------------------------ GRADIO UI ------------------------------

with gr.Blocks() as demo:
    gr.Markdown("## 🧑➡👩 Face Swapper (Male+Female sources)")

    with gr.Row():
        # Left column: the three file-path image inputs plus controls.
        with gr.Column():
            image_input, male_input, female_input = (
                gr.Image(type="filepath", label=caption)
                for caption in ("Target Image", "Source Male", "Source Female")
            )
            enhancer = gr.Dropdown(
                ["NONE"] + cv2_interpolations, value="NONE", label="Face Enhancer"
            )
            run_btn = gr.Button("✨ Swap")

        # Right column: the swapped result.
        with gr.Column():
            output_image = gr.Image(label="Output")

    # Inputs are listed in swap_faces' parameter order:
    # (image_path, male_source_path, female_source_path, face_enhancer_name).
    run_btn.click(
        swap_faces,
        [image_input, male_input, female_input, enhancer],
        output_image,
    )

if __name__ == "__main__":
    # Start the Gradio app.
    demo.launch()