import base64
from io import BytesIO

import gradio as gr
import PIL.Image
import spaces
import torch
from diffusers import AutoencoderKL, AutoencoderTiny, StableDiffusionPipeline
from peft import PeftModel
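
# Load the one-step SDXS base pipeline, attach the anime LoRA adapter to the
# UNet via PEFT, then merge the adapter weights into the base UNet so inference
# runs without the adapter wrapper.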
device = "cuda"
weight_type = torch.float16
pipe = StableDiffusionPipeline.from_pretrained("IDKiro/sdxs-512-dreamshaper")
pipe.unet = PeftModel.from_pretrained(pipe.unet, "IDKiro/sdxs-512-dreamshaper-anime")
# merge_and_unload() returns the merged base model; reassign it so the PEFT
# wrapper is actually dropped.
pipe.unet = pipe.unet.merge_and_unload()
pipe.to(device, dtype=weight_type)
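
# Two interchangeable image decoders: a tiny VAE that saves GPU memory and the
# full KL VAE that decodes at higher quality; the UI switches between them.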
vae_tiny = AutoencoderTiny.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", subfolder="vae"
)
vae_tiny.to(device, dtype=weight_type)

vae_large = AutoencoderKL.from_pretrained(
    "IDKiro/sdxs-512-dreamshaper", subfolder="vae_large"
)
vae_large.to(device, dtype=weight_type)
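
# Encode a PIL image as a base64 data URL so the frontend can offer the result
# for download without an extra server round-trip.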
def pil_image_to_data_url(img, format="PNG"):
    buffered = BytesIO()
    img.save(buffered, format=format)
    img_str = base64.b64encode(buffered.getvalue()).decode()
    return f"data:image/{format.lower()};base64,{img_str}"
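
# @spaces.GPU requests a ZeroGPU allocation on Hugging Face Spaces for the
# duration of each call. SDXS is a one-step distilled model, so generation
# uses num_inference_steps=1 with guidance disabled (guidance_scale=0.0).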
@spaces.GPU
def run(
    prompt: str,
    device_type="GPU",
    vae_type=None,
    param_dtype="torch.float16",
) -> PIL.Image.Image:
    if vae_type == "tiny vae":
        pipe.vae = vae_tiny
    elif vae_type == "large vae":
        pipe.vae = vae_large
    if device_type == "CPU":
        device = "cpu"
        param_dtype = "torch.float32"  # half precision is poorly supported on CPU
    else:
        device = "cuda"
    pipe.to(
        device,
        torch.float16 if param_dtype == "torch.float16" else torch.float32,
    )
    result = pipe(
        prompt=prompt,
        guidance_scale=0.0,
        num_inference_steps=1,
        output_type="pil",
    ).images[0]
    result_url = pil_image_to_data_url(result)
    return (result, result_url)

examples = [
    "A photo of beautiful mountain with realistic sunset and blue lake, highly detailed, masterpiece",
]
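
# UI: prompt box and run button on the left with device/VAE/dtype options,
# generated 512x512 image on the right.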
with gr.Blocks(css="style.css") as demo:
    gr.Markdown("# SDXS-512-DreamShaper-Anime")
    gr.Markdown("[SDXS: Real-Time One-Step Latent Diffusion Models with Image Conditions](https://arxiv.org/abs/2403.16627) | [GitHub](https://github.com/IDKiro/sdxs)")
    with gr.Group():
        with gr.Row():
            with gr.Column(min_width=685):
                with gr.Row():
                    prompt = gr.Text(
                        label="Prompt",
                        show_label=False,
                        max_lines=1,
                        placeholder="Enter your prompt",
                        container=False,
                    )
                    run_button = gr.Button("Run", scale=0)
                device_choices = ["GPU", "CPU"]
                device_type = gr.Radio(
                    device_choices,
                    label="Device",
                    value=device_choices[0],
                    interactive=True,
                    info="Thanks to the community for the GPU!",
                )
                vae_choices = ["tiny vae", "large vae"]
                vae_type = gr.Radio(
                    vae_choices,
                    label="Image Decoder Type",
                    value=vae_choices[0],
                    interactive=True,
                    info="To save GPU memory, use tiny vae. For better quality, use large vae.",
                )
                dtype_choices = ["torch.float16", "torch.float32"]
                param_dtype = gr.Radio(
                    dtype_choices,
                    label="torch.weight_type",
                    value=dtype_choices[0],
                    interactive=True,
                    info="To save GPU memory, use torch.float16. For better quality, use torch.float32.",
                )
                download_output = gr.Button(
                    "Download output", elem_id="download_output"
                )
            with gr.Column(min_width=512):
                result = gr.Image(
                    label="Result",
                    height=512,
                    width=512,
                    elem_id="output_image",
                    show_label=False,
                    show_download_button=True,
                )
    gr.Examples(examples=examples, inputs=prompt, outputs=[result, download_output], fn=run)
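    # No-op load event; presumably a leftover hook (e.g. for client-side JS).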
    demo.load(None, None, None)
    inputs = [prompt, device_type, vae_type, param_dtype]
    outputs = [result, download_output]
    prompt.submit(fn=run, inputs=inputs, outputs=outputs)
    run_button.click(fn=run, inputs=inputs, outputs=outputs)
if __name__ == "__main__":
    demo.queue().launch(debug=True)