File size: 4,241 Bytes
2a5630b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# 首先,确保安装了必要的库
# 你可以使用以下命令安装:
# pip install gradio diffusers transformers torch

import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import requests
from io import BytesIO

# 定义可用的扩散模型列表
AVAILABLE_MODELS = {
    "Stable Diffusion v1.4": "CompVis/stable-diffusion-v1-4",
    "Stable Diffusion v1.5": "runwayml/stable-diffusion-v1-5",
    "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
    # 你可以根据需要添加更多模型
}

# 示例图片的URL列表
SAMPLE_IMAGES = {
    "风景": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/landscape.jpg",
    "人像": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/portrait.jpg",
    "动物": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/animal.jpg",
}

# 使用缓存来存储已加载的模型,以避免重复加载
model_cache = {}

def load_model(model_name):
    if model_name in model_cache:
        return model_cache[model_name]
    else:
        model_id = AVAILABLE_MODELS[model_name]
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
        )
        pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe.to("cpu")
        model_cache[model_name] = pipe
        return pipe

def process_image(model_name, input_image, sample_choice):
    # 如果用户选择使用示例图片,则下载示例图片
    if sample_choice != "上传图片":
        url = SAMPLE_IMAGES.get(sample_choice, SAMPLE_IMAGES["风景"])
        response = requests.get(url)
        input_image = Image.open(BytesIO(response.content)).convert("RGB")
    
    # 加载所选模型
    pipe = load_model(model_name)
    
    # 生成图像(这里以文本提示为例,可以根据实际模型功能调整)
    prompt = "A transformed version of the input image."
    
    with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
        generated_image = pipe(prompt=prompt, init_image=input_image, strength=0.8).images[0]
    
    return input_image, generated_image

# 定义 Gradio 接口
def main():
    with gr.Blocks() as demo:
        gr.Markdown("# Diffusers 扩散模型展示页面")
        gr.Markdown("选择一个模型,上传一张图片或选择一个示例图片,然后点击转换按钮查看结果。")
        
        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value=list(AVAILABLE_MODELS.keys())[0],
                label="选择模型"
            )
        
        with gr.Row():
            sample_radio = gr.Radio(
                choices=["上传图片"] + list(SAMPLE_IMAGES.keys()),
                value="上传图片",
                label="选择图片来源"
            )
        
        with gr.Row():
            input_image = gr.Image(
                type="pil",
                label="上传图片",
                visible=False
            )
            sample_image = gr.Image(
                type="pil",
                label="示例图片",
                visible=False
            )
        
        # 根据用户选择显示上传或示例图片
        def toggle_image(choice):
            return {
                "input_image": gr.update(visible=(choice == "上传图片")),
                "sample_image": gr.update(visible=(choice != "上传图片"))
            }
        
        sample_radio.change(toggle_image, inputs=sample_radio, outputs=[input_image, sample_image])
        
        convert_button = gr.Button("转换")
        
        with gr.Row():
            original_output = gr.Image(label="原图")
            generated_output = gr.Image(label="生成图")
        
        convert_button.click(
            process_image,
            inputs=[model_dropdown, input_image, sample_radio],
            outputs=[original_output, generated_output]
        )
    
    demo.launch(server_port=16006)

if __name__ == "__main__":
    main()