import gradio as gr
import numpy as np
import random
import os
import torch
from diffusers import StableDiffusionPipeline
from peft import PeftModel, LoraConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Use half precision on GPU, full precision on CPU.
if torch.cuda.is_available():
    torch_dtype = torch.float16
else:
    torch_dtype = torch.float32

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024
def get_lora_sd_pipeline(
    ckpt_dir='./model_output',
    base_model_name_or_path=model_id_default,
    dtype=torch_dtype,
    device=device,
    adapter_name="pusheen"
):
    """Load the base Stable Diffusion pipeline and attach the LoRA adapters
    stored in ckpt_dir (UNet adapter required, text encoder adapter optional)."""
    unet_sub_dir = os.path.join(ckpt_dir, "unet")
    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")

    # If no base model was given, recover it from the text encoder LoRA config.
    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
        base_model_name_or_path = config.base_model_name_or_path

    if base_model_name_or_path is None:
        raise ValueError("Please specify the base model name or path")

    pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)

    if os.path.exists(text_encoder_sub_dir):
        pipe.text_encoder = PeftModel.from_pretrained(
            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
        )

    if dtype in (torch.float16, torch.bfloat16):
        pipe.unet.half()
        pipe.text_encoder.half()

    pipe.to(device)
    return pipe
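
# Example usage sketch (assumes ./model_output actually contains LoRA weights
# trained against the base model above; the prompt is hypothetical):
#   pipe = get_lora_sd_pipeline(ckpt_dir="./model_output", base_model_name_or_path=model_id_default)
#   image = pipe("pusheen cat drinking a coffee", num_inference_steps=20).images[0]
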
# def encode_prompt(prompt, tokenizer, text_encoder):
#     text_inputs = tokenizer(
#         prompt,
#         padding="max_length",
#         max_length=tokenizer.model_max_length,
#         return_tensors="pt",
#     )
#     with torch.no_grad():
#         if len(text_inputs.input_ids[0]) < tokenizer.model_max_length:
#             prompt_embeds = text_encoder(text_inputs.input_ids.to(text_encoder.device))[0]
#         else:
#             embeds = []
#             start = 0
#             while start < tokenizer.model_max_length:
#                 end = start + tokenizer.model_max_length
#                 part_of_text_inputs = text_inputs.input_ids[0][start:end]
#                 if len(part_of_text_inputs) < tokenizer.model_max_length:
#                     part_of_text_inputs = torch.cat([part_of_text_inputs, torch.tensor([tokenizer.pad_token_id] * (tokenizer.model_max_length - len(part_of_text_inputs)))])
#                 embeds.append(text_encoder(part_of_text_inputs.to(text_encoder.device).unsqueeze(0))[0])
#                 start += int((8 / 11) * tokenizer.model_max_length)
#             prompt_embeds = torch.mean(torch.stack(embeds, dim=0), dim=0)
#     return prompt_embeds
# @spaces.GPU #[uncomment to use ZeroGPU]
def infer(
    prompt,
    negative_prompt,
    width=512,
    height=512,
    model_id=model_id_default,
    seed=42,
    guidance_scale=7.0,
    lora_scale=1.0,
    num_inference_steps=20,
    progress=gr.Progress(track_tqdm=True),
):
    """Generate one image with the LoRA-adapted Stable Diffusion pipeline."""
    # Cast the seed to int: Gradio number inputs may deliver floats.
    generator = torch.Generator(device).manual_seed(int(seed))
    # Note: the pipeline is re-loaded on every call.
    pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
    pipe = pipe.to(device)
    pipe.fuse_lora(lora_scale=lora_scale)
    # The safety checker is disabled for this demo.
    pipe.safety_checker = None
    # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
    # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
    image = pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
        width=width,
        height=height,
        generator=generator,
    ).images[0]
    return image
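
# The function can also be called outside the Gradio UI, e.g. (hypothetical prompt):
#   image = infer("pusheen cat reading a book", "low quality, blurry", seed=42)
#   image.save("out.png")
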
css = """
#col-container {
margin: 0 auto;
max-width: 640px;
}
"""
with gr.Blocks(css=css, fill_height=True) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown("# Text-to-Image demo")

        with gr.Row():
            model_id = gr.Textbox(
                label="Model ID",
                max_lines=1,
                placeholder="Enter model id",
                value=model_id_default,
            )

        prompt = gr.Textbox(
            label="Prompt",
            max_lines=1,
            placeholder="Enter your prompt",
        )

        negative_prompt = gr.Textbox(
            label="Negative prompt",
            max_lines=1,
            placeholder="Enter your negative prompt",
        )

        with gr.Row():
            seed = gr.Number(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                precision=0,  # return an integer seed
                value=42,
            )

            guidance_scale = gr.Slider(
                label="Guidance scale",
                minimum=0.0,
                maximum=10.0,
                step=0.1,
                value=7.0,  # Replace with defaults that work for your model
            )

        with gr.Row():
            lora_scale = gr.Slider(
                label="LoRA scale",
                minimum=0.0,
                maximum=1.0,
                step=0.1,
                value=1.0,
            )

            num_inference_steps = gr.Slider(
                label="Number of inference steps",
                minimum=1,
                maximum=50,
                step=1,
                value=20,  # Replace with defaults that work for your model
            )

        with gr.Accordion("Optional Settings", open=False):
            with gr.Row():
                width = gr.Slider(
                    label="Width",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

                height = gr.Slider(
                    label="Height",
                    minimum=256,
                    maximum=MAX_IMAGE_SIZE,
                    step=32,
                    value=1024,  # Replace with defaults that work for your model
                )

        run_button = gr.Button("Run", scale=0, variant="primary")

        result = gr.Image(label="Result", show_label=False)
    gr.on(
        triggers=[run_button.click],
        fn=infer,
        # The order of inputs must match the parameter order of infer.
        inputs=[
            prompt,
            negative_prompt,
            width,
            height,
            model_id,
            seed,
            guidance_scale,
            lora_scale,
            num_inference_steps,
        ],
        outputs=[result],
    )
if __name__ == "__main__":
    demo.launch()