ktrndy committed on
Commit ce1e24f · verified · 1 Parent(s): f60a1b2

Update app.py

Files changed (1):
  1. app.py +17 -7
app.py CHANGED
@@ -8,7 +8,7 @@ from peft import PeftModel, LoraConfig
 from diffusers import DiffusionPipeline
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5" # Replace to the model you would like to use
+model_id_default = "stable-diffusion-v1-5/stable-diffusion-v1-5"
 
 if torch.cuda.is_available():
     torch_dtype = torch.float16
@@ -20,19 +20,28 @@ MAX_IMAGE_SIZE = 1024
 
 
 def get_lora_sd_pipeline(
-    ckpt_dir='./output',
+    ckpt_dir='./model_output',
     base_model_name_or_path=model_id_default,
     dtype=torch_dtype,
-    device=device
+    device=device,
+    adapter_name="pusheen"
 ):
     unet_sub_dir = os.path.join(ckpt_dir, "unet")
-
+    text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
+    if os.path.exists(text_encoder_sub_dir) and base_model_name_or_path is None:
+        config = LoraConfig.from_pretrained(text_encoder_sub_dir)
+        base_model_name_or_path = config.base_model_name_or_path
+
     if base_model_name_or_path is None:
         raise ValueError("Please specify the base model name or path")
 
     pipe = StableDiffusionPipeline.from_pretrained(base_model_name_or_path, torch_dtype=dtype).to(device)
-    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, torch_dtype=torch_dtype)
-    pipe.text_encoder = PeftModel.from_pretrained(pipe.text_encoder, base_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype)
+    pipe.unet = PeftModel.from_pretrained(pipe.unet, unet_sub_dir, adapter_name=adapter_name)
+
+    if os.path.exists(text_encoder_sub_dir):
+        pipe.text_encoder = PeftModel.from_pretrained(
+            pipe.text_encoder, text_encoder_sub_dir, adapter_name=adapter_name
+        )
 
     if dtype in (torch.float16, torch.bfloat16):
         pipe.unet.half()
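
Review note: the reworked loader assumes a PEFT-style checkpoint layout, i.e. a UNet adapter under ckpt_dir/unet and an optional text-encoder adapter under ckpt_dir/text_encoder, each saved with PeftModel.save_pretrained(). A minimal usage sketch of the new signature (the import path, prompt, and output filename are placeholders, not part of this commit):

    from app import get_lora_sd_pipeline  # assuming this Space's app.py is importable

    # Attach the "pusheen" adapters from ./model_output to the base pipeline.
    pipe = get_lora_sd_pipeline(ckpt_dir="./model_output", adapter_name="pusheen")

    # Standard diffusers call; returns a StableDiffusionPipelineOutput.
    image = pipe("a pusheen cat reading a book").images[0]
    image.save("sample.png")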
@@ -84,7 +93,8 @@ def infer(
     generator = torch.Generator(device).manual_seed(seed)
     pipe = get_lora_sd_pipeline(base_model_name_or_path=model_id)
     pipe = pipe.to(device)
-    # pipe.fuse_lora(lora_scale=lora_scale)
+    pipe.fuse_lora(lora_scale=lora_scale)
+    pipe.safety_checker = None
     # prompt_embeds = encode_prompt(prompt, pipe.tokenizer, pipe.text_encoder)
     # negative_prompt_embeds = encode_prompt(negative_prompt, pipe.tokenizer, pipe.text_encoder)
 
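Review note: this hunk switches fuse_lora() on (baking the adapter into the weights at lora_scale) and sets pipe.safety_checker = None, so images are returned without the NSFW filter. Since the UNet here is wrapped in a PeftModel rather than loaded through pipe.load_lora_weights(), it is worth verifying that diffusers' fuse_lora() actually picks these adapters up; a PEFT-native sketch that merges them instead (merge_and_unload() is standard PEFT API, though it takes no scale argument):

    # Fold the LoRA weights into the base modules and drop the PEFT wrappers.
    pipe.unet = pipe.unet.merge_and_unload()
    if hasattr(pipe.text_encoder, "merge_and_unload"):  # present only if a text-encoder adapter was attached
        pipe.text_encoder = pipe.text_encoder.merge_and_unload()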