S2O_DPM / lcm.py
Mayuri's picture
Upload 10 files
2a5630b verified
# 首先,确保安装了必要的库
# 你可以使用以下命令安装:
# pip install gradio diffusers transformers torch
import gradio as gr
from diffusers import StableDiffusionPipeline
import torch
from PIL import Image
import requests
from io import BytesIO
# 定义可用的扩散模型列表
AVAILABLE_MODELS = {
"Stable Diffusion v1.4": "CompVis/stable-diffusion-v1-4",
"Stable Diffusion v1.5": "runwayml/stable-diffusion-v1-5",
"Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
# 你可以根据需要添加更多模型
}
# 示例图片的URL列表
SAMPLE_IMAGES = {
"风景": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/landscape.jpg",
"人像": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/portrait.jpg",
"动物": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/animal.jpg",
}
# 使用缓存来存储已加载的模型,以避免重复加载
model_cache = {}
def load_model(model_name):
if model_name in model_cache:
return model_cache[model_name]
else:
model_id = AVAILABLE_MODELS[model_name]
pipe = StableDiffusionPipeline.from_pretrained(
model_id,
torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32
)
pipe = pipe.to("cuda") if torch.cuda.is_available() else pipe.to("cpu")
model_cache[model_name] = pipe
return pipe
def process_image(model_name, input_image, sample_choice):
# 如果用户选择使用示例图片,则下载示例图片
if sample_choice != "上传图片":
url = SAMPLE_IMAGES.get(sample_choice, SAMPLE_IMAGES["风景"])
response = requests.get(url)
input_image = Image.open(BytesIO(response.content)).convert("RGB")
# 加载所选模型
pipe = load_model(model_name)
# 生成图像(这里以文本提示为例,可以根据实际模型功能调整)
prompt = "A transformed version of the input image."
with torch.autocast("cuda" if torch.cuda.is_available() else "cpu"):
generated_image = pipe(prompt=prompt, init_image=input_image, strength=0.8).images[0]
return input_image, generated_image
# 定义 Gradio 接口
def main():
with gr.Blocks() as demo:
gr.Markdown("# Diffusers 扩散模型展示页面")
gr.Markdown("选择一个模型,上传一张图片或选择一个示例图片,然后点击转换按钮查看结果。")
with gr.Row():
model_dropdown = gr.Dropdown(
choices=list(AVAILABLE_MODELS.keys()),
value=list(AVAILABLE_MODELS.keys())[0],
label="选择模型"
)
with gr.Row():
sample_radio = gr.Radio(
choices=["上传图片"] + list(SAMPLE_IMAGES.keys()),
value="上传图片",
label="选择图片来源"
)
with gr.Row():
input_image = gr.Image(
type="pil",
label="上传图片",
visible=False
)
sample_image = gr.Image(
type="pil",
label="示例图片",
visible=False
)
# 根据用户选择显示上传或示例图片
def toggle_image(choice):
return {
"input_image": gr.update(visible=(choice == "上传图片")),
"sample_image": gr.update(visible=(choice != "上传图片"))
}
sample_radio.change(toggle_image, inputs=sample_radio, outputs=[input_image, sample_image])
convert_button = gr.Button("转换")
with gr.Row():
original_output = gr.Image(label="原图")
generated_output = gr.Image(label="生成图")
convert_button.click(
process_image,
inputs=[model_dropdown, input_image, sample_radio],
outputs=[original_output, generated_output]
)
demo.launch(server_port=16006)
if __name__ == "__main__":
main()