import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, DDIMScheduler
from diffusers.utils import load_image
from peft import PeftModel, LoraConfig
from rembg import remove

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Use fp16 on GPU, fp32 on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024


# @spaces.GPU  # [uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    controlnet_checkbox=False,
    controlnet_strength=0.0,
    controlnet_mode="edge_detection",
    controlnet_image=None,
    ip_adapter_checkbox=False,
    ip_adapter_scale=0.0,
    ip_adapter_image=None,
    tiny_vae=False,
    ddim=False,
    del_background=False,
    alpha_matting=False,
    alpha_matting_foreground_threshold=240,
    alpha_matting_background_threshold=10,
    alpha_matting_erode_size=10,
    post_process_mask=False,
    progress=gr.Progress(track_tqdm=True),
):
    if model_id is None:
        raise ValueError("Please specify the base model name or path")

    # Pick the LoRA checkpoint directory that matches the selected base model.
    if model_id == model_id_default:
        ckpt_dir = './model_output'
    elif 'base' in model_id:
        ckpt_dir = './model_output_distilled_base'
    else:
        ckpt_dir = './model_output_distilled_small'
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    # gr.Number returns a float, so cast the seed to int for manual_seed.
    generator = torch.Generator(device).manual_seed(int(seed))
    params = {
        'prompt': prompt,
        'negative_prompt': negative_prompt,
        'guidance_scale': guidance_scale,
        'num_inference_steps': num_inference_steps,
        'width': width,
        'height': height,
        'generator': generator,
        'cross_attention_kwargs': {"scale": lora_scale},
    }

    if controlnet_checkbox:
        # Pick the ControlNet checkpoint that matches the requested conditioning mode.
        if controlnet_mode == "depth_map":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-depth",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            )
        elif controlnet_mode == "pose_estimation":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-openpose",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            )
        elif controlnet_mode == "normal_map":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-normal",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            )
        elif controlnet_mode == "scribbles":
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-scribble",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            )
        else:
            controlnet = ControlNetModel.from_pretrained(
                "lllyasviel/sd-controlnet-canny",
                cache_dir="./models_cache",
                torch_dtype=torch_dtype,
            )
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            model_id,
            controlnet=controlnet,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)
        params['image'] = controlnet_image
        params['controlnet_conditioning_scale'] = float(controlnet_strength)
    else:
        pipe = StableDiffusionPipeline.from_pretrained(
            model_id,
            torch_dtype=torch_dtype,
            safety_checker=None,
        ).to(device)

    # Attach the fine-tuned LoRA adapters to the UNet and text encoder.
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir)
    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, text_encoder_sub_dir)
    # pipe.unet.add_weighted_adapter(['default'], [lora_scale], 'lora')
    # pipe.text_encoder.add_weighted_adapter(['default'], [lora_scale], 'lora')
    # pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
    # pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
    # Optional speed/quality trade-offs: a tiny VAE and a DDIM scheduler.
    if tiny_vae:
        pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch_dtype)
    if ddim:
        pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)

    # Keep the PEFT-wrapped modules in the same dtype as the rest of the pipeline.
    if torch_dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()

    if ip_adapter_checkbox:
        pipe.load_ip_adapter(
            "h94/IP-Adapter",
            subfolder="models",
            weight_name="ip-adapter-plus_sd15.bin",
        )
        pipe.set_ip_adapter_scale(ip_adapter_scale)
        params['ip_adapter_image'] = ip_adapter_image

    pipe.to(device)

    # Optionally strip the background from the generated image with rembg.
    if del_background:
        return remove(
            pipe(**params).images[0],
            alpha_matting=alpha_matting,
            alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
            alpha_matting_background_threshold=alpha_matting_background_threshold,
            alpha_matting_erode_size=alpha_matting_erode_size,
            post_process_mask=post_process_mask,
        )
    else:
        return pipe(**params).images[0]


css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""


def controlnet_params(show_extra):
    return gr.update(visible=show_extra)


with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Text-to-Image demo")

        with gr.Row():
            model_id = gr.Dropdown(
                label="Model ID",
                choices=[
                    model_id_default,
                    "nota-ai/bk-sdm-v2-base",
                    "nota-ai/bk-sdm-v2-small",
                ],
                value=model_id_default,
                max_choices=1,
            )

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )
        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )

        with gr.Row():
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=42,
            )
            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=30.0,
                step=0.1,
                value=7.0,  # Replace with defaults that work for your model
            )

        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.01,
                value=1.0,
            )
            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=100,
                step=1,
                value=20,  # Replace with defaults that work for your model
            )

        with gr.Row():
            tiny_vae = gr.Checkbox(
                label="Use AutoencoderTiny?",
                value=False,
            )
            ddim = gr.Checkbox(
                label="Use DDIMScheduler?",
                value=False,
            )

        with gr.Row():
            del_background = gr.Checkbox(
                label="Delete background?",
                value=False,
            )
            with gr.Column(visible=False) as rembg_params:
                alpha_matting = gr.Checkbox(
                    label="alpha_matting",
                    value=False,
                )
                with gr.Column(visible=False) as alpha_params:
                    alpha_matting_foreground_threshold = gr.Slider(
                        label="alpha_matting_foreground_threshold",
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=240,
                    )
                    alpha_matting_background_threshold = gr.Slider(
                        label="alpha_matting_background_threshold",
                        minimum=0,
                        maximum=255,
                        step=1,
                        value=10,
                    )
                    alpha_matting_erode_size = gr.Slider(
                        label="alpha_matting_erode_size",
                        minimum=0,
                        maximum=100,
                        step=1,
                        value=10,
                    )
                alpha_matting.change(
                    fn=lambda x: gr.update(visible=x),
                    inputs=alpha_matting,
                    outputs=alpha_params,
                )
                post_process_mask = gr.Checkbox(
                    label="post_process_mask",
                    value=False,
                )
            del_background.change(
                fn=lambda x: gr.update(visible=x),
                inputs=del_background,
                outputs=rembg_params,
            )

        with gr.Row():
            controlnet_checkbox = gr.Checkbox(
                label="ControlNet",
                value=False,
            )
            with gr.Column(visible=False) as controlnet_params:
                controlnet_strength = gr.Slider(
                    label="ControlNet conditioning scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                controlnet_mode = gr.Dropdown(
                    label="ControlNet mode",
                    choices=[
                        "edge_detection",
                        "depth_map",
                        "pose_estimation",
                        "normal_map",
                        "scribbles",
                    ],
                    value="edge_detection",
                    max_choices=1,
                )
                controlnet_image = gr.Image(
                    label="ControlNet condition image",
                    type="pil",
                    format="png",
                )
            controlnet_checkbox.change(
                fn=lambda x: gr.update(visible=x),
                inputs=controlnet_checkbox,
                outputs=controlnet_params,
            )

        with gr.Row():
            ip_adapter_checkbox = gr.Checkbox(
                label="IPAdapter",
                value=False,
            )
            with gr.Column(visible=False) as ip_adapter_params:
                ip_adapter_scale = gr.Slider(
                    label="IPAdapter scale",
                    minimum=0.0,
                    maximum=1.0,
                    step=0.01,
                    value=1.0,
                )
                ip_adapter_image = gr.Image(
                    label="IPAdapter condition image",
                    type="pil",
                )
            ip_adapter_checkbox.change(
                fn=lambda x: gr.update(visible=x),
                inputs=ip_adapter_checkbox,
                outputs=ip_adapter_params,
            )

        with gr.Accordion("Optional Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )
                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=512,  # Replace with defaults that work for your model
                )

        run_button = gr.Button("Run", scale=0, variant="primary")
        result = gr.Image(label="Result", show_label=False)

    gr.on(
        triggers=[run_button.click],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
            num_inference_steps,
            controlnet_checkbox,
            controlnet_strength,
            controlnet_mode,
            controlnet_image,
            ip_adapter_checkbox,
            ip_adapter_scale,
            ip_adapter_image,
            tiny_vae,
            ddim,
            del_background,
            alpha_matting,
            alpha_matting_foreground_threshold,
            alpha_matting_background_threshold,
            alpha_matting_erode_size,
            post_process_mask,
        ],
        outputs=[result],
    )

if __name__ == "__main__":
    demo.launch()
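
# A minimal usage sketch (not part of the app): calling infer() directly,
# bypassing the Gradio UI. It assumes the LoRA checkpoints expected under
# ./model_output are present; the prompt, seed, and file name below are
# illustrative only.
#
# image = infer(
#     prompt="a photo of a cat in a spacesuit",
#     negative_prompt="low quality, blurry",
#     seed=123,
#     num_inference_steps=20,
# )
# image.save("example.png")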