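# Gradio demo: Stable Diffusion 1.5 text-to-image with locally trained LoRA adapters,
# plus optional ControlNet conditioning, IP-Adapter, tiny VAE, DDIM scheduler and
# rembg background removal.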
import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, DDIMScheduler
from diffusers.utils import load_image
from peft import PeftModel, LoraConfig
from rembg import remove
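
# Prefer the GPU with fp16 weights when CUDA is available; otherwise fall back to CPU with fp32.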
device = "cuda" if torch.cuda.is_available() else "cpu"
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
if torch.cuda.is_available():
torch_dtype = torch.float16
else:
torch_dtype = torch.float32
MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
prompt,
negative_prompt,
width=512,
height=512,
model_id=model_id_default,
seed=42,
guidance_scale=7.0,
lora_scale=1.0,
num_inference_steps=20,
controlnet_checkbox=False,
controlnet_strength=0.0,
controlnet_mode="edge_detection",
controlnet_image=None,
ip_adapter_checkbox=False,
ip_adapter_scale=0.0,
ip_adapter_image=None,
tiny_vae=False,
ddim=False,
del_background=False,
alpha_matting=False,
alpha_matting_foreground_threshold=240,
alpha_matting_background_threshold=10,
alpha_matting_erode_size=10,
post_process_mask=False,
progress=gr.Progress(track_tqdm=True),
):
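    # Generate one image from the prompt with the selected base model, its locally
    # trained LoRA adapters and the optional ControlNet / IP-Adapter inputs.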
    if model_id is None:
        raise ValueError("Please specify the base model name or path")

    # Each supported base model has its own locally trained LoRA checkpoint directory.
    if model_id == model_id_default:
        ckpt_dir = './model_output'
    elif 'base' in model_id:
        ckpt_dir = './model_output_distilled_base'
    else:
        ckpt_dir = './model_output_distilled_small'

    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
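
    # Seeded generator so the same seed reproduces the same image.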
generator = torch.Generator(device).manual_seed(seed)
params = {'prompt': prompt,
'negative_prompt': negative_prompt,
'guidance_scale': guidance_scale,
'num_inference_steps': num_inference_steps,
'width': width,
'height': height,
'generator': generator,
'cross_attention_kwargs': {"scale": lora_scale}
}
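
    # With ControlNet enabled, load the pretrained ControlNet matching the selected mode
    # (canny edge detection by default) and build a ControlNet pipeline; otherwise use
    # the plain text-to-image pipeline.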
if controlnet_checkbox:
if controlnet_mode == "depth_map":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-depth",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "pose_estimation":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-openpose",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "normal_map":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-normal",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
elif controlnet_mode == "scribbles":
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-scribble",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
else:
controlnet = ControlNetModel.from_pretrained(
"lllyasviel/sd-controlnet-canny",
cache_dir="./models_cache",
torch_dtype=torch_dtype
)
pipe = StableDiffusionControlNetPipeline.from_pretrained(model_id,
controlnet=controlnet,
torch_dtype=torch_dtype,
safety_checker=None).to(device)
params['image'] = controlnet_image
params['controlnet_conditioning_scale'] = float(controlnet_strength)
else:
pipe = StableDiffusionPipeline.from_pretrained(model_id,
torch_dtype=torch_dtype,
safety_checker=None).to(device)
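
    # Attach the locally trained LoRA adapters to both the UNet and the text encoder.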
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
# pipe.unet.add_weighted_adapter(['default'], [lora_scale], 'lora')
# pipe.text_encoder.add_weighted_adapter(['default'], [lora_scale], 'lora')
# pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
# pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
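
    # Optional components: tiny TAESD VAE for faster decoding and the DDIM scheduler.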
if tiny_vae:
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch_dtype)
if ddim:
pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
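
    # Keep the LoRA-wrapped UNet and text encoder in fp16 so they match the pipeline dtype.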
if torch_dtype in (torch.float16, torch.bfloat16):
pipe.unet.half()
pipe.text_encoder.half()
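
    # Optionally condition generation on a reference image via IP-Adapter.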
if ip_adapter_checkbox:
pipe.load_ip_adapter("h94/IP-Adapter", subfolder="models", weight_name="ip-adapter-plus_sd15.bin")
pipe.set_ip_adapter_scale(ip_adapter_scale)
params['ip_adapter_image'] = ip_adapter_image
pipe.to(device)
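
    # Optionally strip the background from the generated image with rembg.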
if del_background:
return remove(pipe(**params).images[0],
alpha_matting=alpha_matting,
alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
alpha_matting_background_threshold=alpha_matting_background_threshold,
alpha_matting_erode_size=alpha_matting_erode_size,
post_process_mask=post_process_mask
)
else:
return pipe(**params).images[0]
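
# Note: infer() can also be called directly, e.g. infer("a photo of a cat", "") returns
# a PIL image, but in this app it is only invoked through the Gradio handler below.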
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
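
# Gradio UI: model and prompt controls plus collapsible panels for background removal,
# ControlNet and IP-Adapter whose visibility is toggled by the checkboxes below.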
with gr.Blocks(css=css, fill_height=True) as demo:
with gr.Column(elem_id="col-container"):
gr.Markdown(" # Text-to-Image demo")
with gr.Row():
model_id = gr.Dropdown(
label="Model ID",
choices=[model_id_default,
"nota-ai/bk-sdm-v2-base",
"nota-ai/bk-sdm-v2-small"],
value=model_id_default,
max_choices=1
)
prompt = gr.Textbox(
label="Prompt",
max_lines=1,
placeholder="Enter your prompt",
)
negative_prompt = gr.Textbox(
label="Negative prompt",
max_lines=1,
placeholder="Enter your negative prompt",
)
with gr.Row():
seed = gr.Number(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
guidance_scale = gr.Slider(
label="Guidance scale",
minimum=0.0,
maximum=30.0,
step=0.1,
value=7.0, # Replace with defaults that work for your model
)
with gr.Row():
lora_scale = gr.Slider(
label="LoRA scale",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0,
)
num_inference_steps = gr.Slider(
label="Number of inference steps",
minimum=1,
maximum=100,
step=1,
value=20, # Replace with defaults that work for your model
)
with gr.Row():
tiny_vae = gr.Checkbox(
label="Use AutoencoderTiny?",
value=False
)
ddim = gr.Checkbox(
label="Use DDIMScheduler?",
value=False
)
with gr.Row():
del_background = gr.Checkbox(
label="Delete background?",
value=False
)
with gr.Column(visible=False) as rembg_params:
alpha_matting = gr.Checkbox(
label="alpha_matting",
value=False
)
with gr.Column(visible=False) as alpha_params:
alpha_matting_foreground_threshold = gr.Slider(
label="alpha_matting_foreground_threshold",
minimum=0,
maximum=255,
step=1,
value=240,
)
alpha_matting_background_threshold = gr.Slider(
label="alpha_matting_background_threshold",
minimum=0,
maximum=255,
step=1,
value=10,
)
alpha_matting_erode_size = gr.Slider(
label="alpha_matting_erode_size",
minimum=0,
maximum=100,
step=1,
value=10,
)
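                # Show the alpha-matting sliders only while alpha matting is enabled.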
alpha_matting.change(
                    fn=lambda x: gr.update(visible=x),
inputs=alpha_matting,
outputs=alpha_params
)
post_process_mask = gr.Checkbox(
label="post_process_mask",
value=False
)
del_background.change(
                fn=lambda x: gr.update(visible=x),
inputs=del_background,
outputs=rembg_params
)
with gr.Row():
controlnet_checkbox = gr.Checkbox(
label="ControlNet",
value=False
)
with gr.Column(visible=False) as controlnet_params:
controlnet_strength = gr.Slider(
label="ControlNet conditioning scale",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0,
)
controlnet_mode = gr.Dropdown(
label="ControlNet mode",
choices=["edge_detection",
"depth_map",
"pose_estimation",
"normal_map",
"scribbles"],
value="edge_detection",
max_choices=1
)
controlnet_image = gr.Image(
label="ControlNet condition image",
type="pil",
format="png"
)
controlnet_checkbox.change(
                fn=lambda x: gr.update(visible=x),
inputs=controlnet_checkbox,
outputs=controlnet_params
)
with gr.Row():
ip_adapter_checkbox = gr.Checkbox(
label="IPAdapter",
value=False
)
with gr.Column(visible=False) as ip_adapter_params:
ip_adapter_scale = gr.Slider(
label="IPAdapter scale",
minimum=0.0,
maximum=1.0,
step=0.01,
value=1.0,
)
ip_adapter_image = gr.Image(
label="IPAdapter condition image",
type="pil"
)
ip_adapter_checkbox.change(
                fn=lambda x: gr.update(visible=x),
inputs=ip_adapter_checkbox,
outputs=ip_adapter_params
)
with gr.Accordion("Optional Settings", open=False):
with gr.Row():
width = gr.Slider(
label="Width",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512, # Replace with defaults that work for your model
)
height = gr.Slider(
label="Height",
minimum=256,
maximum=MAX_IMAGE_SIZE,
step=32,
value=512, # Replace with defaults that work for your model
)
run_button = gr.Button("Run", scale=0, variant="primary")
result = gr.Image(label="Result", show_label=False)
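
    # Wire the Run button to infer(); the input order must match infer's parameter order.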
gr.on(
triggers=[run_button.click],
fn=infer,
inputs=[
prompt,
negative_prompt,
width,
height,
model_id,
seed,
guidance_scale,
lora_scale,
num_inference_steps,
controlnet_checkbox,
controlnet_strength,
controlnet_mode,
controlnet_image,
ip_adapter_checkbox,
ip_adapter_scale,
ip_adapter_image,
tiny_vae,
ddim,
del_background,
alpha_matting,
alpha_matting_foreground_threshold,
alpha_matting_background_threshold,
alpha_matting_erode_size,
post_process_mask,
],
outputs=[result],
)
if __name__ == "__main__":
demo.launch()