Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,7 +8,7 @@ from peft import PeftModel, LoraConfig
|
|
8 |
from diffusers import DiffusionPipeline
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
-
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
12 |
|
13 |
if torch.cuda.is_available():
|
14 |
torch_dtype = torch.float16
|
@@ -20,19 +20,28 @@ MAX_IMAGE_SIZE = 1024
|
|
20 |
|
21 |
|
22 |
def get_lora_sd_pipeline(
|
23 |
-
ckpt_dir='./
|
24 |
base_model_name_or_path=model_id_default,
|
25 |
dtype=torch_dtype,
|
26 |
-
device=device
|
|
|
27 |
):
|
28 |
unet_sub_dir = os.path.join(ckpt_dir, "unet")
|
29 |
-
|
|
|
|
|
|
|
|
|
30 |
if base_model_name_or_path is None:
|
31 |
raise ValueError("Please specify the base model name or path")
|
32 |
|
33 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
|
34 |
-
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir,
|
35 |
-
|
|
|
|
|
|
|
|
|
36 |
|
37 |
if dtype in (torch.float16, torch.bfloat16):
|
38 |
pipe.unet.half()
|
@@ -84,7 +93,8 @@ def infer(
|
|
84 |
generator = torch.Generator(device).manual_seed(seed)
|
85 |
pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
|
86 |
pipe = pipe.to(device)
|
87 |
-
|
|
|
88 |
# prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
|
89 |
# negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
|
90 |
|
|
|
8 |
from diffusers import DiffusionPipeline
|
9 |
|
10 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
11 |
+
model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
|
12 |
|
13 |
if torch.cuda.is_available():
|
14 |
torch_dtype = torch.float16
|
|
|
20 |
|
21 |
|
22 |
def get_lora_sd_pipeline(
|
23 |
+
ckpt_dir='./model_output',
|
24 |
base_model_name_or_path=model_id_default,
|
25 |
dtype=torch_dtype,
|
26 |
+
device=device,
|
27 |
+
adapter_name="pusheen"
|
28 |
):
|
29 |
unet_sub_dir = os.path.join(ckpt_dir, "unet")
|
30 |
+
text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
|
31 |
+
if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
|
32 |
+
config = LoraConfig.from_pretrained(text_encoder_sub_dir)
|
33 |
+
base_model_name_or_path = config.base_model_name_or_path
|
34 |
+
|
35 |
if base_model_name_or_path is None:
|
36 |
raise ValueError("Please specify the base model name or path")
|
37 |
|
38 |
pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
|
39 |
+
pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
|
40 |
+
|
41 |
+
if os.path.exists(text_encoder_sub_dir):
|
42 |
+
pipe.text_encoder = PeftModel.from_pretrained(
|
43 |
+
pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
|
44 |
+
)
|
45 |
|
46 |
if dtype in (torch.float16, torch.bfloat16):
|
47 |
pipe.unet.half()
|
|
|
93 |
generator = torch.Generator(device).manual_seed(seed)
|
94 |
pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
|
95 |
pipe = pipe.to(device)
|
96 |
+
pipe.fuse_lora(lora_scale=lora_scale)
|
97 |
+
pipe.safety_checker = None
|
98 |
# prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
|
99 |
# negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
|
100 |
|