ktrndy commited on
Commit
a643bb2
·
verified ·
1 Parent(s): 4c555ac

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +94 -9
app.py CHANGED
@@ -3,7 +3,7 @@ import numpy as np
3
  import random
4
  import os
5
  import torch
6
- from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline
7
  from diffusers.utils import load_image
8
  from peft import PeftModel, LoraConfig
9
  from rembg import remove
@@ -39,10 +39,26 @@ def infer(
39
  ip_adapter_checkbox=False,
40
  ip_adapter_scale=0.0,
41
  ip_adapter_image=None,
 
 
 
 
42
  del_background=False,
 
 
 
 
 
 
43
  progress=gr.Progress(track_tqdm=True),
44
  ):
45
- ckpt_dir='./model_output'
 
 
 
 
 
 
46
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
47
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
48
 
@@ -106,6 +122,12 @@ def infer(
106
 
107
  pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
108
  pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
 
 
 
 
 
 
109
 
110
  if torch_dtype in (torch.float16, torch.bfloat16):
111
  pipe.unet.half()
@@ -119,7 +141,13 @@ def infer(
119
  pipe.to(device)
120
 
121
  if del_background:
122
- return remove(pipe(**params).images[0])
 
 
 
 
 
 
123
  else:
124
  return pipe(**params).images[0]
125
 
@@ -139,12 +167,15 @@ with gr.Blocks(css=css, fill_height=True) as demo:
139
  gr.Markdown(" # Text-to-Image demo")
140
 
141
  with gr.Row():
142
- model_id = gr.Textbox(
143
- label="Model ID",
144
- max_lines=1,
145
- placeholder="Enter model id",
146
- value=model_id_default,
147
- )
 
 
 
148
 
149
  prompt = gr.Textbox(
150
  label="Prompt",
@@ -190,11 +221,58 @@ with gr.Blocks(css=css, fill_height=True) as demo:
190
  step=1,
191
  value=20, # Replace with defaults that work for your model
192
  )
 
 
 
 
 
 
 
 
 
 
193
  with gr.Row():
194
  del_background = gr.Checkbox(
195
  label="Delete background?",
196
  value=False
197
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  with gr.Row():
199
  controlnet_checkbox = gr.Checkbox(
200
  label="ControlNet",
@@ -294,7 +372,14 @@ with gr.Blocks(css=css, fill_height=True) as demo:
294
  ip_adapter_checkbox,
295
  ip_adapter_scale,
296
  ip_adapter_image,
 
 
297
  del_background,
 
 
 
 
 
298
  ],
299
  outputs=[result],
300
  )
 
3
  import random
4
  import os
5
  import torch
6
+ from diffusers import StableDiffusionPipeline, ControlNetModel, StableDiffusionControlNetPipeline, AutoencoderTiny, DDIMScheduler
7
  from diffusers.utils import load_image
8
  from peft import PeftModel, LoraConfig
9
  from rembg import remove
 
39
  ip_adapter_checkbox=False,
40
  ip_adapter_scale=0.0,
41
  ip_adapter_image=None,
42
+
43
+ tiny_vae=False,
44
+ ddim=False,
45
+
46
  del_background=False,
47
+ alpha_matting=False,
48
+ alpha_matting_foreground_threshold=240,
49
+ alpha_matting_background_threshold=10,
50
+ alpha_matting_erode_size=10,
51
+ post_process_mask=False,
52
+
53
  progress=gr.Progress(track_tqdm=True),
54
  ):
55
+ if model_id == model_id_default:
56
+ ckpt_dir='./model_output'
57
+ elif 'base' in model_id:
58
+ ckpt_dir='./model_output_distilled_base'
59
+ else:
60
+ ckpt_dir='./model_output_distilled_small'
61
+
62
  unet_sub_dir = os.path.join(ckpt_dir, "unet")
63
  text_encoder_sub_dir = os.path.join(ckpt_dir, "text_encoder")
64
 
 
122
 
123
  pipe.unet.load_state_dict({k: lora_scale*v for k, v in pipe.unet.state_dict().items()})
124
  pipe.text_encoder.load_state_dict({k: lora_scale*v for k, v in pipe.text_encoder.state_dict().items()})
125
+
126
+ if tiny_vae:
127
+ pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch_dtype)
128
+
129
+ if ddim:
130
+ pipe.scheduler = DDIMScheduler.from_config(pipe.scheduler.config)
131
 
132
  if torch_dtype in (torch.float16, torch.bfloat16):
133
  pipe.unet.half()
 
141
  pipe.to(device)
142
 
143
  if del_background:
144
+ return remove(pipe(**params).images[0],
145
+ alpha_matting=alpha_matting,
146
+ alpha_matting_foreground_threshold=alpha_matting_foreground_threshold,
147
+ alpha_matting_background_threshold=alpha_matting_background_threshold,
148
+ alpha_matting_erode_size=alpha_matting_erode_size,
149
+ post_process_mask=post_process_mask
150
+ )
151
  else:
152
  return pipe(**params).images[0]
153
 
 
167
  gr.Markdown(" # Text-to-Image demo")
168
 
169
  with gr.Row():
170
+ model_id = gr.Dropdown(
171
+ label="Model ID",
172
+ choices=[model_id_default,
173
+ "nota-ai/bk-sdm-v2-base",
174
+ "nota-ai/bk-sdm-v2-small"],
175
+ value=model_id_default,
176
+ max_choices=1
177
+ )
178
+
179
 
180
  prompt = gr.Textbox(
181
  label="Prompt",
 
221
  step=1,
222
  value=20, # Replace with defaults that work for your model
223
  )
224
+ with gr.Row():
225
+ tiny_vae = = gr.Checkbox(
226
+ label="Use AutoencoderTiny?",
227
+ value=False
228
+ )
229
+ ddim = = gr.Checkbox(
230
+ label="Use DDIMScheduler?",
231
+ value=False
232
+ )
233
+
234
  with gr.Row():
235
  del_background = gr.Checkbox(
236
  label="Delete background?",
237
  value=False
238
  )
239
+ with gr.Column(visible=False) as rembg_params:
240
+ alpha_matting = gr.Checkbox(
241
+ label="alpha_matting",
242
+ value=False
243
+ )
244
+ with gr.Column(visible=False) as alpha_params:
245
+ alpha_matting_foreground_threshold = gr.Slider(
246
+ label="alpha_matting_foreground_threshold",
247
+ minimum=0,
248
+ maximum=255,
249
+ step=1,
250
+ value=240,
251
+ )
252
+ alpha_matting_background_threshold = gr.Slider(
253
+ label="alpha_matting_background_threshold",
254
+ minimum=0,
255
+ maximum=255,
256
+ step=1,
257
+ value=10,
258
+ )
259
+ alpha_matting_erode_size = gr.Slider(
260
+ label="alpha_matting_erode_size",
261
+ minimum=0,
262
+ maximum=100,
263
+ step=1,
264
+ value=10,
265
+ )
266
+ alpha_matting.change(
267
+ fn=lambda x: gr.Row.update(visible=x),
268
+ inputs=alpha_matting,
269
+ outputs=alpha_params
270
+ )
271
+ post_process_mask = gr.Checkbox(
272
+ label="post_process_mask",
273
+ value=False
274
+ )
275
+
276
  with gr.Row():
277
  controlnet_checkbox = gr.Checkbox(
278
  label="ControlNet",
 
372
  ip_adapter_checkbox,
373
  ip_adapter_scale,
374
  ip_adapter_image,
375
+ tiny_vae,
376
+ ddim,
377
  del_background,
378
+ alpha_matting,
379
+ alpha_matting_foreground_threshold,
380
+ alpha_matting_background_threshold,
381
+ alpha_matting_erode_size,
382
+ post_process_mask,
383
  ],
384
  outputs=[result],
385
  )