import os
import time
import random
import tempfile
import torch
import gradio as gr
from PIL import Image
import spaces
from gradio import processing_utils, utils
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    StableDiffusionControlNetPipeline,
    StableDiffusionControlNetImg2ImgPipeline,
    StableDiffusionLatentUpscalePipeline,
    DPMSolverMultistepScheduler,
    EulerDiscreteScheduler,
)
from share_btn import community_icon_html, loading_icon_html, share_js
import user_history
from illusion_style import css
# -----------------------------
# Device & dtype (GPU/CPU auto)
# -----------------------------
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
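# fp16 roughly halves VRAM use on GPU; the CPU fallback stays in fp32, where half precision is poorly supported.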
# -----------------------------
# Base / ControlNet models
# -----------------------------
BASE_MODEL = "SG161222/Realistic_Vision_V5.1_noVAE"
VAE_ID = "stabilityai/sd-vae-ft-mse"
CONTROLNET_ID = "monster-labs/control_v1p_sd15_qrcode_monster"
# -----------------------------
# Load components
# -----------------------------
vae = AutoencoderKL.from_pretrained(VAE_ID, torch_dtype=dtype)
controlnet = ControlNetModel.from_pretrained(CONTROLNET_ID, torch_dtype=dtype)
# ⚠️ safety checker & CLIP feature extractor removed
main_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    vae=vae,
    safety_checker=None,     # <= important
    feature_extractor=None,  # <= important
    torch_dtype=dtype,
)
main_pipe = main_pipe.to(device)
# Img2Img pipe reusing components
image_pipe = StableDiffusionControlNetImg2ImgPipeline(**main_pipe.components)
image_pipe = image_pipe.to(device)
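# Building the img2img pipeline from main_pipe.components shares the already-loaded
# UNet, VAE, text encoder and ControlNet, so the weights are not loaded (or kept in memory) twice.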
# -----------------------------
# Sampler map
# -----------------------------
SAMPLER_MAP = {
    # `use_karras_sigmas` is the parameter name diffusers recognizes for Karras noise schedules
    "DPM++ Karras SDE": lambda config: DPMSolverMultistepScheduler.from_config(
        config, use_karras_sigmas=True, algorithm_type="sde-dpmsolver++"
    ),
    "Euler": lambda config: EulerDiscreteScheduler.from_config(config),
}
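# Usage (illustrative): main_pipe.scheduler = SAMPLER_MAP["Euler"](main_pipe.scheduler.config)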
# -----------------------------
# Helpers
# -----------------------------
def center_crop_resize(img: Image.Image, output_size=(512, 512)):
    width, height = img.size
    new_dim = min(width, height)
    left = (width - new_dim) / 2
    top = (height - new_dim) / 2
    right = (width + new_dim) / 2
    bottom = (height + new_dim) / 2
    img = img.crop((left, top, right, bottom))
    img = img.resize(output_size)
    return img
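# center_crop_resize example (illustrative): a 1920x1080 input is cropped to its central
# 1080x1080 square and then resized to 512x512, so the illusion pattern is not stretched.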
def common_upscale(samples, width, height, upscale_method, crop=False):
    if crop == "center":
        old_w = samples.shape[3]
        old_h = samples.shape[2]
        old_aspect = old_w / old_h
        new_aspect = width / height
        x = 0
        y = 0
        if old_aspect > new_aspect:
            x = round((old_w - old_w * (new_aspect / old_aspect)) / 2)
        elif old_aspect < new_aspect:
            y = round((old_h - old_h * (old_aspect / new_aspect)) / 2)
        s = samples[:, :, y : old_h - y, x : old_w - x]
    else:
        s = samples
    return torch.nn.functional.interpolate(s, size=(height, width), mode=upscale_method)
def upscale(samples, upscale_method, scale_by):
    width = round(samples["images"].shape[3] * scale_by)
    height = round(samples["images"].shape[2] * scale_by)
    s = common_upscale(samples["images"], width, height, upscale_method, "disabled")
    return s
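# upscale(): the first-pass latents are 1/8 of the pixel resolution (4 x 64 x 64 for a
# 512x512 image), so scale_by=2 yields 128x128 latents, i.e. a 1024x1024 decode.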
def check_inputs(prompt: str, control_image: Image.Image):
    if control_image is None:
        raise gr.Error("Please select or upload an Input Illusion")
    if not prompt:
        raise gr.Error("Prompt is required")
# -----------------------------
# Inference
# -----------------------------
@spaces.GPU
def inference(
    control_image: Image.Image,
    prompt: str,
    negative_prompt: str,
    guidance_scale: float = 8.0,
    controlnet_conditioning_scale: float = 1.0,
    control_guidance_start: float = 1.0,
    control_guidance_end: float = 1.0,
    upscaler_strength: float = 0.5,
    seed: int = -1,
    sampler: str = "DPM++ Karras SDE",
    progress=gr.Progress(track_tqdm=True),
    profile: gr.OAuthProfile | None = None,
):
    start_time = time.time()
    control_image_small = center_crop_resize(control_image, (512, 512))
    control_image_large = center_crop_resize(control_image, (1024, 1024))
    main_pipe.scheduler = SAMPLER_MAP[sampler](main_pipe.scheduler.config)
    my_seed = random.randint(0, 2**32 - 1) if seed == -1 else int(seed)
    generator = torch.Generator(device=device).manual_seed(my_seed)
    # First pass -> latents
    out = main_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=control_image_small,
        guidance_scale=float(guidance_scale),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
        generator=generator,
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        num_inference_steps=15,
        output_type="latent",
    )
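    # output_type="latent" skips the VAE decode here, so the raw latents can be upscaled
    # and fed straight into the img2img refinement pass below.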
    # Upscale latents
    upscaled_latents = upscale(out, "nearest-exact", 2)
    # Second pass -> image
    out_image = image_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        control_image=control_image_large,
        image=upscaled_latents,
        guidance_scale=float(guidance_scale),
        generator=generator,
        num_inference_steps=20,
        strength=float(upscaler_strength),
        control_guidance_start=float(control_guidance_start),
        control_guidance_end=float(control_guidance_end),
        controlnet_conditioning_scale=float(controlnet_conditioning_scale),
    )
    # Save history
    user_history.save_image(
        label=prompt,
        image=out_image["images"][0],
        profile=profile,
        metadata={
            "prompt": prompt,
            "negative_prompt": negative_prompt,
            "guidance_scale": guidance_scale,
            "controlnet_conditioning_scale": controlnet_conditioning_scale,
            "control_guidance_start": control_guidance_start,
            "control_guidance_end": control_guidance_end,
            "upscaler_strength": upscaler_strength,
            "seed": my_seed,
            "sampler": sampler,
        },
    )
    return out_image["images"][0], gr.update(visible=True), gr.update(visible=True), my_seed
# -----------------------------
# UI
# -----------------------------
with gr.Blocks() as app:
    gr.Markdown(
        '''
        <div style="text-align: center;">
            <h1>Illusion Diffusion HQ 🌀</h1>
            <p style="font-size:16px;">Generate high-quality illusion artwork with Stable Diffusion + ControlNet</p>
            <p>A space by AP with contributions from the community.</p>
            <p>This uses <a href="https://huggingface.co/monster-labs/control_v1p_sd15_qrcode_monster">Monster Labs QR ControlNet</a>.</p>
        </div>
        '''
    )
    state_img_input = gr.State()
    state_img_output = gr.State()
    with gr.Row():
        with gr.Column():
            control_image = gr.Image(label="Input Illusion", type="pil", elem_id="control_image")
            controlnet_conditioning_scale = gr.Slider(minimum=0.0, maximum=5.0, step=0.01, value=0.8, label="Illusion strength", elem_id="illusion_strength", info="ControlNet conditioning scale")
            gr.Examples(
                examples=["checkers.png", "checkers_mid.jpg", "pattern.png", "ultra_checkers.png", "spiral.jpeg", "funky.jpeg"],
                inputs=control_image
            )
            prompt = gr.Textbox(label="Prompt", elem_id="prompt", info="Type what you want to generate", placeholder="Medieval village scene with busy streets and a castle in the distance")
            negative_prompt = gr.Textbox(label="Negative Prompt", info="What you do NOT want", value="low quality, blurry", elem_id="negative_prompt")
            with gr.Accordion(label="Advanced Options", open=False):
                guidance_scale = gr.Slider(minimum=0.0, maximum=50.0, step=0.25, value=7.5, label="Guidance Scale")
                sampler = gr.Dropdown(choices=list(SAMPLER_MAP.keys()), value="Euler", label="Sampler")
                control_start = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.0, label="Start of ControlNet")
                control_end = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="End of ControlNet")
                strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.1, value=1.0, label="Strength of the upscaler")
                seed = gr.Slider(minimum=-1, maximum=9999999999, step=1, value=-1, label="Seed", info="-1 = random")
                used_seed = gr.Number(label="Last seed used", interactive=False)
            run_btn = gr.Button("Run")
        with gr.Column():
            result_image = gr.Image(label="Illusion Diffusion Output", interactive=False, elem_id="output")
            with gr.Group(elem_id="share-btn-container", visible=False) as share_group:
                community_icon = gr.HTML(community_icon_html)
                loading_icon = gr.HTML(loading_icon_html)
                share_button = gr.Button("Share to community", elem_id="share-btn")
    # Wire up
    prompt.submit(
        check_inputs,
        inputs=[prompt, control_image],
        queue=False
    ).success(
        inference,
        inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
        outputs=[result_image, result_image, share_group, used_seed]
    )
    run_btn.click(
        check_inputs,
        inputs=[prompt, control_image],
        queue=False
    ).success(
        inference,
        inputs=[control_image, prompt, negative_prompt, guidance_scale, controlnet_conditioning_scale, control_start, control_end, strength, seed, sampler],
        outputs=[result_image, result_image, share_group, used_seed]
    )
    share_button.click(None, [], [], js=share_js)
with gr.Blocks(css=css) as app_with_history:
    with gr.Tab("Demo"):
        app.render()
    with gr.Tab("Past generations"):
        user_history.render()
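# queue(max_size=20) caps pending requests at 20; api_open=False keeps the raw prediction
# endpoints closed so requests have to go through the queue.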
app_with_history.queue(max_size=20, api_open=False)
if __name__ == "__main__":
    app_with_history.launch(max_threads=400)
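# When launched locally (e.g. `python app.py`), Gradio serves the UI on http://127.0.0.1:7860 by default.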