import gradio as gr
# The demo transforms an existing image, so the image-to-image pipeline is
# needed rather than the text-to-image StableDiffusionPipeline.
from diffusers import StableDiffusionImg2ImgPipeline
import torch
from PIL import Image
import requests
from io import BytesIO

# Models available in the dropdown, mapped to their Hugging Face Hub IDs.
AVAILABLE_MODELS = {
    "Stable Diffusion v1.4": "CompVis/stable-diffusion-v1-4",
    "Stable Diffusion v1.5": "runwayml/stable-diffusion-v1-5",
    "Stable Diffusion 2.1": "stabilityai/stable-diffusion-2-1",
}

# Sample images the user can pick instead of uploading their own.
SAMPLE_IMAGES = {
    "Landscape": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/landscape.jpg",
    "Portrait": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/portrait.jpg",
    "Animal": "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/samples/animal.jpg",
}

# Cache loaded pipelines so switching back to a model does not reload it.
model_cache = {}

def load_model(model_name):
    """Load the requested pipeline, reusing a cached instance when possible."""
    if model_name in model_cache:
        return model_cache[model_name]

    model_id = AVAILABLE_MODELS[model_name]
    device = "cuda" if torch.cuda.is_available() else "cpu"
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        model_id,
        torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    )
    pipe = pipe.to(device)
    model_cache[model_name] = pipe
    return pipe

def process_image(model_name, input_image, sample_choice):
    # Use a sample image unless the user chose to upload their own.
    if sample_choice != "Upload image":
        url = SAMPLE_IMAGES.get(sample_choice, SAMPLE_IMAGES["Landscape"])
        response = requests.get(url)
        input_image = Image.open(BytesIO(response.content)).convert("RGB")

    if input_image is None:
        raise gr.Error("Please upload an image or pick a sample image.")

    pipe = load_model(model_name)

    # Resize so both dimensions are multiples of 8, as the pipeline requires.
    input_image = input_image.resize((512, 512))

    prompt = "A transformed version of the input image."

    # strength controls how far the output may deviate from the input:
    # higher values add more noise and transform the image more strongly.
    if torch.cuda.is_available():
        with torch.autocast("cuda"):
            generated_image = pipe(prompt=prompt, image=input_image, strength=0.8).images[0]
    else:
        generated_image = pipe(prompt=prompt, image=input_image, strength=0.8).images[0]

    return input_image, generated_image

def main():
    with gr.Blocks() as demo:
        gr.Markdown("# Diffusers Model Demo")
        gr.Markdown(
            "Select a model, upload an image or pick a sample image, "
            "then click Convert to see the result."
        )

        with gr.Row():
            model_dropdown = gr.Dropdown(
                choices=list(AVAILABLE_MODELS.keys()),
                value=list(AVAILABLE_MODELS.keys())[0],
                label="Model"
            )

        with gr.Row():
            sample_radio = gr.Radio(
                choices=["Upload image"] + list(SAMPLE_IMAGES.keys()),
                value="Upload image",
                label="Image source"
            )

        with gr.Row():
            # The upload widget starts visible because "Upload image" is the
            # initial radio selection; the sample preview starts hidden.
            input_image = gr.Image(
                type="pil",
                label="Upload image",
                visible=True
            )
            sample_image = gr.Image(
                type="pil",
                label="Sample image",
                visible=False
            )

        def toggle_image(choice):
            # Return one update per output component, in the same order as
            # the `outputs` list of the change event below.
            return (
                gr.update(visible=(choice == "Upload image")),
                gr.update(visible=(choice != "Upload image")),
            )

        sample_radio.change(toggle_image, inputs=sample_radio, outputs=[input_image, sample_image])

        convert_button = gr.Button("Convert")

        with gr.Row():
            original_output = gr.Image(label="Original image")
            generated_output = gr.Image(label="Generated image")

        convert_button.click(
            process_image,
            inputs=[model_dropdown, input_image, sample_radio],
            outputs=[original_output, generated_output]
        )

    demo.launch(server_port=16006)


if __name__ == "__main__":
    main()
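
# Usage sketch (assumptions: this file is saved as app.py and the listed
# model IDs are reachable from the Hugging Face Hub):
#
#   pip install gradio diffusers transformers accelerate torch
#   python app.py
#
# The demo then serves on http://localhost:16006, the port passed to
# demo.launch above. On a CPU-only machine generation will be slow; a CUDA
# GPU lets the pipelines run in fp16 as configured in load_model.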