File size: 4,449 Bytes
f4cd92d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import os
import webbrowser
import customtkinter as ctk
from typing import Callable, Tuple
import cv2
from PIL import Image, ImageOps
import tkinterdnd2 as tkdnd
import gradio as gr
import traceback
import logging
import io

import modules.globals
import modules.metadata
from modules.face_analyser import (
    get_one_face,
    get_unique_faces_from_target_image,
    get_unique_faces_from_target_video,
    add_blank_map,
    has_valid_map,
    simplify_maps,
)
from modules.capturer import get_video_frame, get_video_frame_total
from modules.processors.frame.core import get_frame_processors_modules
from modules.utilities import (
    is_image,
    is_video,
    resolve_relative_path,
    has_image_extension,
)
import gradio as gr

# In-memory text buffer handed to the root logging configuration as its
# stream, so that formatted log output can later be read back as a string.
log_capture_string = io.StringIO()
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    stream=log_capture_string,
)

def create_ui(start, destroy):
    """Build the Gradio face-swap interface.

    Args:
        start: Zero-argument callable invoked to run the processing once
            ``modules.globals.source_path`` and ``modules.globals.target_path``
            have been set from the UI state.
        destroy: Accepted for interface compatibility; not used in this
            function.

    Returns:
        The constructed ``gr.Blocks`` application (not launched here).
    """

    def process(src_path, tgt_path):
        """Run the swap for the selected source and target.

        Args:
            src_path (str): Path of the source image.
            tgt_path (str): Path of the target image or video.

        Returns:
            tuple: ``(status message, captured log text)``.
        """
        # Guard clause: both paths must be non-empty before anything runs.
        if not (src_path and tgt_path):
            return "请先选择源图像和目标图像/视频", ""

        try:
            logging.info(f"源路径: {src_path}")
            logging.info(f"目标路径: {tgt_path}")
            modules.globals.source_path = src_path
            modules.globals.target_path = tgt_path
            logging.info("开始处理...")
            start()
            # The captured log lives in the module-level ``log_capture_string``
            # buffer; the previous name ``error_log_capture`` was undefined.
            return "处理完成", log_capture_string.getvalue()
        except Exception as e:
            # UI boundary: report the failure in the log box instead of
            # letting the exception escape the event handler.
            error_msg = f"处理过程中出错: {str(e)}\n{traceback.format_exc()}"
            logging.error(error_msg)
            return "处理失败", log_capture_string.getvalue()

    def update_source(image):
        """Persist the uploaded source image and return its path.

        Args:
            image (PIL.Image): The uploaded image, or ``None`` when cleared.

        Returns:
            str: Path of the saved image, or ``""`` when no image is present.
        """
        if image is None:
            return ""
        temp_path = "temp_source.png"
        image.save(temp_path)
        return temp_path

    def update_target(file):
        """Record the uploaded target and produce a preview for it.

        Args:
            file: The uploaded file (an object exposing ``.name`` or a plain
                path string), or ``None`` when cleared.

        Returns:
            tuple: ``(target path, preview update)`` where the preview update
            carries both the preview image value and its visibility.
        """
        if file is not None:
            # Accept either an upload object with a ``name`` attribute or a
            # bare path string.
            file_path = getattr(file, "name", file)
            if is_image(file_path):
                return file_path, gr.update(value=file_path, visible=True)
            if is_video(file_path):
                video = cv2.VideoCapture(file_path)
                try:
                    success, frame = video.read()
                finally:
                    # Release the capture even when the first frame cannot be
                    # read; previously it leaked on that path.
                    video.release()
                if success:
                    preview_path = "temp_preview.jpg"
                    cv2.imwrite(preview_path, frame)
                    return file_path, gr.update(value=preview_path, visible=True)
        # Nothing usable: clear the stored path and hide the preview.
        return "", gr.update(value=None, visible=False)

    # Build the Gradio interface.
    with gr.Blocks() as demo:
        # Per-session holders for the selected file paths. Created inside the
        # Blocks context so they are registered with this app before being
        # used as event inputs/outputs.
        source_path = gr.State(value="")
        target_path = gr.State(value="")

        gr.Markdown("# 人脸交换")

        with gr.Row():
            source_image = gr.Image(label="源图像", type="pil", elem_id="source_image")
            target_file = gr.File(label="目标图像/视频", elem_id="target_file")

        target_preview = gr.Image(label="目标预览", visible=False, elem_id="target_preview")

        process_btn = gr.Button("开始处理", elem_id="process_btn")

        output = gr.Textbox(label="输出信息", elem_id="output")

        # Area that displays the captured log text.
        error_log_output = gr.Textbox(label="错误日志", lines=5, elem_id="error_log")

        # Wire component events.
        source_image.change(update_source, inputs=[source_image], outputs=[source_path])
        # ``target_preview`` appears once; its value and visibility travel
        # together in a single update rather than as two duplicate outputs.
        target_file.change(update_target, inputs=[target_file], outputs=[target_path, target_preview])
        process_btn.click(process, inputs=[source_path, target_path], outputs=[output, error_log_output])

    return demo