Jinl committed
Commit 9bf9ce7 · 1 Parent(s): 9c3237f

initial add

Files changed (34)
  1. .gitignore +16 -0
  2. app.py +529 -0
  3. app.sh +7 -0
  4. pdiff/pdiff_pipeline.py +275 -0
  5. requirements.txt +6 -0
  6. style.css +3 -0
  7. utils/__init__.py +0 -0
  8. utils/__pycache__/__init__.cpython-310.pyc +0 -0
  9. utils/__pycache__/__init__.cpython-38.pyc +0 -0
  10. utils/__pycache__/__init__.cpython-39.pyc +0 -0
  11. utils/__pycache__/convert_from_ckpt.cpython-310.pyc +0 -0
  12. utils/__pycache__/convert_from_ckpt.cpython-38.pyc +0 -0
  13. utils/__pycache__/convert_from_ckpt.cpython-39.pyc +0 -0
  14. utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc +0 -0
  15. utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc +0 -0
  16. utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-39.pyc +0 -0
  17. utils/__pycache__/diffuser_utils.cpython-310.pyc +0 -0
  18. utils/__pycache__/diffuser_utils.cpython-38.pyc +0 -0
  19. utils/__pycache__/diffuser_utils.cpython-39.pyc +0 -0
  20. utils/__pycache__/free_lunch_utils.cpython-310.pyc +0 -0
  21. utils/__pycache__/free_lunch_utils.cpython-38.pyc +0 -0
  22. utils/__pycache__/free_lunch_utils.cpython-39.pyc +0 -0
  23. utils/__pycache__/masactrl_utils.cpython-310.pyc +0 -0
  24. utils/__pycache__/masactrl_utils.cpython-38.pyc +0 -0
  25. utils/__pycache__/masactrl_utils.cpython-39.pyc +0 -0
  26. utils/__pycache__/style_attn_control.cpython-310.pyc +0 -0
  27. utils/__pycache__/style_attn_control.cpython-38.pyc +0 -0
  28. utils/__pycache__/style_attn_control.cpython-39.pyc +0 -0
  29. utils/convert_from_ckpt.py +959 -0
  30. utils/convert_lora_safetensor_to_diffusers.py +154 -0
  31. utils/diffuser_utils.py +275 -0
  32. utils/free_lunch_utils.py +334 -0
  33. utils/masactrl_utils.py +212 -0
  34. utils/style_attn_control.py +275 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
+ ./data
+ ./results
+ ./results_ablation
+ ./workdir
+ ./row_results
+ ./new_res
+ ./cop
+ examper
+ results
+ data
+ results_ablation
+ row_results
+ new_res
+ cop
+ ./samples
+ samples
app.py ADDED
@@ -0,0 +1,529 @@
1
+ import os
2
+ import json
3
+ import re
5
+ import torch
6
+ import random
7
+ import numpy as np
8
+ import gradio as gr
9
+ from glob import glob
10
+ from omegaconf import OmegaConf
11
+ from datetime import datetime
12
+ from safetensors import safe_open
13
+
14
+ from diffusers import AutoencoderKL,UNet2DConditionModel,StableDiffusionPipeline
15
+ from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
16
+ from diffusers.utils.import_utils import is_xformers_available
17
+ from transformers import CLIPTextModel, CLIPTokenizer
18
+ from utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
19
+ from utils.convert_lora_safetensor_to_diffusers import convert_lora
20
+
21
+ import torch.nn.functional as F
22
+ from PIL import Image
23
+
24
+ from utils.diffuser_utils import MasaCtrlPipeline
25
+ from utils.masactrl_utils import (AttentionBase,
26
+ regiter_attention_editor_diffusers)
27
+ from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d
28
+
29
+ from utils.style_attn_control import MaskPromptedStyleAttentionControl
30
+ from torchvision.utils import save_image
31
+ from diffusers.models.attention_processor import AttnProcessor2_0
32
+
33
+
34
+
35
+
36
+ css = """
37
+ .toolbutton {
38
+ margin: 0em 0em 0em 0em;
39
+ max-width: 2.5em;
40
+ min-width: 2.5em !important;
41
+ height: 2.5em;
42
+ }
43
+ """
44
+
45
+ class GlobalText:
46
+ def __init__(self):
47
+
48
+ # config dirs
49
+ self.basedir = os.getcwd()
50
+ self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
51
+ self.personalized_model_dir = '/home/jin.liu/liujin/webui/stable-diffusion-webui/models/Stable-diffusion'
52
+ self.lora_model_dir = '/home/jin.liu/liujin/webui/stable-diffusion-webui/models/Lora'
53
+ self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
54
+ self.savedir_sample = os.path.join(self.savedir, "sample")
55
+
56
+ self.savedir_mask = os.path.join(self.savedir, "mask")
57
+
58
+ self.stable_diffusion_list = ["/home/jin.liu/liujin/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9",
59
+ "runwayml/stable-diffusion-v1-5",
60
+ "stabilityai/stable-diffusion-2-1"]
61
+ self.personalized_model_list = []
62
+ self.lora_model_list = []
63
+
64
+ # config models
65
+ self.tokenizer = None
66
+ self.text_encoder = None
67
+ self.vae = None
68
+ self.unet = None
69
+ self.pipeline = None
70
+ self.lora_loaded = None
71
+ self.personal_model_loaded = None
72
+ self.lora_model_state_dict = {}
73
+ self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
74
+ self.refresh_stable_diffusion()
75
+ self.refresh_personalized_model()
76
+
77
+ self.reset_start_code()
78
+ def load_base_pipeline(self, model_path):
79
+ print(f'loading {model_path} model')
80
+ scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler")
81
+ self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
82
+ scheduler=scheduler).to(self.device)
83
+
84
+ def refresh_stable_diffusion(self):
85
+
86
+ self.load_base_pipeline(self.stable_diffusion_list[0])
87
+ self.lora_loaded = None
88
+ self.personal_model_loaded = None
89
+ return self.stable_diffusion_list[0]
90
+
91
+ def refresh_personalized_model(self):
92
+ personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True)
93
+ self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list}
94
+
95
+ lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True)
96
+ self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}
97
+
98
+ def update_stable_diffusion(self, stable_diffusion_dropdown):
99
+
100
+ self.load_base_pipeline(stable_diffusion_dropdown)
101
+ self.lora_loaded = None
102
+ self.personal_model_loaded = None
103
+ return gr.Dropdown.update()
104
+
105
+ def update_base_model(self, base_model_dropdown):
106
+ if self.pipeline is None:
107
+ gr.Info(f"Please select a pretrained model path.")
108
+ return None
109
+ else:
110
+ base_model = self.personalized_model_list[base_model_dropdown]
111
+ mid_model = StableDiffusionPipeline.from_single_file(base_model)
112
+ self.pipeline.vae = mid_model.vae
113
+ self.pipeline.unet = mid_model.unet
114
+ self.pipeline.text_encoder = mid_model.text_encoder
115
+ self.pipeline.to(self.device)
116
+ self.personal_model_loaded = base_model_dropdown.split('.')[0]
117
+ print(f'load {base_model_dropdown} model success!')
118
+
119
+ return gr.Dropdown()
120
+
121
+
122
+ def update_lora_model(self, lora_model_dropdown,lora_alpha_slider):
123
+
124
+ if self.pipeline is None:
125
+ gr.Info(f"Please select a pretrained model path.")
126
+ return None
127
+ else:
128
+ if lora_model_dropdown == "none":
129
+ self.pipeline.unfuse_lora()
130
+ self.pipeline.unload_lora_weights()
131
+ self.lora_loaded = None
132
+ # self.personal_model_loaded = None
133
+ print("Restore lora.")
134
+ else:
135
+
136
+ lora_model_path = self.lora_model_list[lora_model_dropdown]#os.path.join(self.lora_model_dir, lora_model_dropdown)
137
+ # self.lora_model_state_dict = {}
138
+ # if lora_model_dropdown == "none": pass
139
+ # else:
140
+ # with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
141
+ # for key in f.keys():
142
+ # self.lora_model_state_dict[key] = f.get_tensor(key)
143
+ # convert_lora(self.pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
144
+ self.pipeline.unfuse_lora()
145
+ self.pipeline.unload_lora_weights()
146
+ self.pipeline.load_lora_weights(lora_model_path)
147
+ self.pipeline.fuse_lora(lora_alpha_slider)
148
+ self.lora_loaded = lora_model_dropdown.split('.')[0]
149
+ print(f'load {lora_model_dropdown} model success!')
150
+ return gr.Dropdown()
151
+
152
+ def generate(self, source, style, source_mask, style_mask,
153
+ start_step, start_layer, Style_attn_step,
154
+ Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
155
+ target_prompt, negative_prompt_textbox,
156
+ inter_latents,
157
+ freeu, b1, b2, s1, s2,
158
+ width_slider,height_slider,
159
+ ):
160
+ os.makedirs(self.savedir, exist_ok=True)
161
+ os.makedirs(self.savedir_sample, exist_ok=True)
162
+ os.makedirs(self.savedir_mask, exist_ok=True)
163
+ model = self.pipeline
164
+
165
+ if seed != -1 and seed != "": torch.manual_seed(int(seed))
166
+ else: torch.seed()
167
+ seed = torch.initial_seed()
168
+ sample_count = len(os.listdir(self.savedir_sample))
169
+ os.makedirs(os.path.join(self.savedir_mask, f"results_{sample_count}"), exist_ok=True)
170
+
171
+ # ref_prompt = [source_prompt, target_prompt]
172
+ # prompts = ref_prompt+['']
173
+ ref_prompt = [target_prompt, target_prompt]
174
+ prompts = ref_prompt+[target_prompt]
175
+ source_image,style_image,source_mask,style_mask = load_mask_images(source,style,source_mask,style_mask,self.device,width_slider,height_slider,out_dir=os.path.join(self.savedir_mask, f"results_{sample_count}"))
176
+
177
+
178
+ # global START_CODE, LATENTS_LIST
179
+
180
+ with torch.no_grad():
181
+ #import pdb;pdb.set_trace()
182
+
183
+ #prev_source
184
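+ # Cache the DDIM inversion outputs (start noise and per-step latents): repeated runs
+ # with the same source/style images reuse them, and reset_start_code() clears the cache
+ # when a new image is uploaded or the DDIM step count changes.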
+ if self.start_code is None and self.latents_list is None:
185
+ content_style = torch.cat([style_image, source_image], dim=0)
186
+ editor = AttentionBase()
187
+ regiter_attention_editor_diffusers(model, editor)
188
+ st_code, latents_list = model.invert(content_style,
189
+ ref_prompt,
190
+ guidance_scale=scale,
191
+ num_inference_steps=ddim_steps,
192
+ return_intermediates=True)
193
+ start_code = torch.cat([st_code, st_code[1:]], dim=0)
194
+ self.start_code = start_code
195
+ self.latents_list = latents_list
196
+ else:
197
+ start_code = self.start_code
198
+ latents_list = self.latents_list
199
+ print('------------------------------------------ Use previous latents ------------------------------------------ ')
200
+
201
+ #["Without mask", "Only masked region", "Seperate Background Foreground"]
202
+
203
+ if Method == "Without mask":
204
+ style_mask = None
205
+ source_mask = None
206
+ only_masked_region = False
207
+ elif Method == "Only masked region":
208
+ assert style_mask is not None and source_mask is not None
209
+ only_masked_region = True
210
+ else:
211
+ assert style_mask is not None and source_mask is not None
212
+ only_masked_region = False
213
+
214
+ controller = MaskPromptedStyleAttentionControl(start_step, start_layer,
215
+ style_attn_step=Style_attn_step,
216
+ style_guidance=Style_Guidance,
217
+ style_mask=style_mask,
218
+ source_mask=source_mask,
219
+ only_masked_region=only_masked_region,
220
+ guidance=scale,
221
+ de_bug=de_bug,
222
+ )
223
+ if freeu:
224
+ print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
225
+ if Method != "Without mask":
226
+ register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2,source_mask=source_mask)
227
+ register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2,source_mask=source_mask)
228
+ else:
229
+ register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2,source_mask=None)
230
+ register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s2,source_mask=None)
231
+
232
+ else:
233
+ print(f'++++++++++++++++++ Run without FreeU ++++++++++++++++')
234
+ register_upblock2d(model)
235
+ register_crossattn_upblock2d(model)
236
+ regiter_attention_editor_diffusers(model, controller)
237
+
238
+ regiter_attention_editor_diffusers(model, controller)
239
+
240
+ # run inference to synthesize the stylized image
241
+ generate_image= model(prompts,
242
+ width=width_slider,
243
+ height=height_slider,
244
+ latents=start_code,
245
+ guidance_scale=scale,
246
+ num_inference_steps=ddim_steps,
247
+ ref_intermediate_latents=latents_list if inter_latents else None,
248
+ neg_prompt=negative_prompt_textbox,
249
+ return_intermediates=False,)
250
+
251
+ # os.makedirs(os.path.join(output_dir, f"results_{sample_count}"))
252
+ save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
253
+ if self.lora_loaded != None:
254
+ save_file_name = f"lora_{self.lora_loaded}_" + save_file_name
255
+ if self.personal_model_loaded != None:
256
+ save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name
257
+ #f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}_lora_{self.lora_loaded}.jpg"
258
+ save_file_path = os.path.join(self.savedir_sample, save_file_name)
259
+ #save_file_name = os.path.join(output_dir, f"results_style_{style_name}", f"{content_name}.jpg")
260
+
261
+ save_image(torch.cat([source_image/2 + 0.5, style_image/2 + 0.5, generate_image[2:]], dim=0), save_file_path, nrow=3, padding=0)
262
+
263
+
264
+
265
+ # global OUTPUT_RESULT
266
+ # OUTPUT_RESULT = save_file_name
267
+
268
+ generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy()
269
+ #save_gif(latents_list, os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif'))
270
+ # import pdb;pdb.set_trace()
271
+ #gif_dir = os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif')
272
+
273
+ return [
274
+ generate_image[0],
275
+ generate_image[1],
276
+ generate_image[2],
277
+ ]
278
+
279
+ def reset_start_code(self,):
280
+ self.start_code = None
281
+ self.latents_list = None
282
+
283
+ global_text = GlobalText()
284
+
285
+
286
+ def load_mask_images(source,style,source_mask,style_mask,device,width,height,out_dir=None):
287
+ # invert the image into noise map
288
+ if isinstance(source['image'], np.ndarray):
289
+ source_image = torch.from_numpy(source['image']).to(device) / 127.5 - 1.
290
+ else:
291
+ source_image = torch.from_numpy(np.array(source['image'])).to(device) / 127.5 - 1.
292
+ source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)
293
+
294
+ source_image = F.interpolate(source_image, (height,width ))
295
+
296
+ if out_dir is not None and source_mask is None:
297
+
298
+ source['mask'].save(os.path.join(out_dir,'source_mask.jpg'))
299
+ else:
300
+ Image.fromarray(source_mask).save(os.path.join(out_dir,'source_mask.jpg'))
301
+ if out_dir is not None and style_mask is None:
302
+
303
+ style['mask'].save(os.path.join(out_dir,'style_mask.jpg'))
304
+ else:
305
+ Image.fromarray(style_mask).save(os.path.join(out_dir,'style_mask.jpg'))
306
+ # save source['mask']
307
+ # import pdb;pdb.set_trace()
308
+ source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
309
+ source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
310
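+ # downsample the mask by 8x so it matches the VAE latent resolution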
+ source_mask = F.interpolate(source_mask, (height//8,width//8))
311
+
312
+ if isinstance(source['image'], np.ndarray):
313
+ style_image = torch.from_numpy(style['image']).to(device) / 127.5 - 1.
314
+ else:
315
+ style_image = torch.from_numpy(np.array(style['image'])).to(device) / 127.5 - 1.
316
+ style_image = style_image.unsqueeze(0).permute(0, 3, 1, 2)
317
+ style_image = F.interpolate(style_image, (height,width))
318
+
319
+ style_mask = torch.from_numpy(np.array(style['mask']) if style_mask is None else style_mask ).to(device) / 255.
320
+ style_mask = style_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
321
+ style_mask = F.interpolate(style_mask, (height//8,width//8))
322
+
323
+
324
+ return source_image,style_image,source_mask,style_mask
325
+
326
+
327
+ def ui():
328
+ with gr.Blocks(css=css) as demo:
329
+ gr.Markdown(
330
+ """
331
+ # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
332
+ Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
333
+ [Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/)
334
+ """
335
+ )
336
+ with gr.Column(variant="panel"):
337
+ gr.Markdown(
338
+ """
339
+ ### 1. Select a pretrained model.
340
+ """
341
+ )
342
+ with gr.Row():
343
+ stable_diffusion_dropdown = gr.Dropdown(
344
+ label="Pretrained Model Path",
345
+ choices=global_text.stable_diffusion_list,
346
+ interactive=True,
347
+ allow_custom_value=True
348
+ )
349
+ stable_diffusion_dropdown.change(fn=global_text.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])
350
+
351
+ stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
352
+ def update_stable_diffusion():
353
+ global_text.refresh_stable_diffusion()
354
+
355
+ stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[])
356
+
357
+ base_model_dropdown = gr.Dropdown(
358
+ label="Select a ckpt model (optional)",
359
+ choices=sorted(list(global_text.personalized_model_list.keys())),
360
+ interactive=True,
361
+ allow_custom_value=True,
362
+ )
363
+ base_model_dropdown.change(fn=global_text.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])
364
+
365
+ lora_model_dropdown = gr.Dropdown(
366
+ label="Select a LoRA model (optional)",
367
+ choices=["none"] + sorted(list(global_text.lora_model_list.keys())),
368
+ value="none",
369
+ interactive=True,
370
+ allow_custom_value=True,
371
+ )
372
+ lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
373
+ lora_model_dropdown.change(fn=global_text.update_lora_model, inputs=[lora_model_dropdown,lora_alpha_slider], outputs=[lora_model_dropdown])
374
+
375
+
376
+
377
+ personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
378
+
379
+ def update_personalized_model():
380
+ global_text.refresh_personalized_model()
381
+ return [
382
+ gr.Dropdown(choices=sorted(list(global_text.personalized_model_list.keys()))),
383
+ gr.Dropdown(choices=["none"] + sorted(list(global_text.lora_model_list.keys())))
384
+ ]
385
+ personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
386
+
387
+
388
+ with gr.Column(variant="panel"):
389
+ gr.Markdown(
390
+ """
391
+ ### 2. Configs for PortraitDiff.
392
+ """
393
+ )
394
+ with gr.Tab("Configs"):
395
+
396
+ with gr.Row():
397
+ source_image = gr.Image(label="Source Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
398
+ style_image = gr.Image(label="Style Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
399
+ with gr.Row():
400
+ prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
401
+ negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
402
+ # output_dir = gr.Textbox(label="output_dir", value='./results/')
403
+
404
+ with gr.Row().style(equal_height=False):
405
+ with gr.Column():
406
+ width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
407
+ height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
408
+ Method = gr.Dropdown(
409
+ ["Without mask", "Only masked region", "Seperate Background Foreground"],
410
+ value="Without mask",
411
+ label="Mask", info="Select how to use masks")
412
+ with gr.Tab('Base Configs'):
413
+ with gr.Row():
414
+ # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
415
+ ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1)
416
+
417
+ Style_attn_step = gr.Slider(label="Step of Style Attention Control",
418
+ minimum=0,
419
+ maximum=50,
420
+ value=35,
421
+ step=1)
422
+ start_step = gr.Slider(label="Step of Attention Control",
423
+ minimum=0,
424
+ maximum=150,
425
+ value=0,
426
+ step=1)
427
+ start_layer = gr.Slider(label="Layer of Style Attention Control",
428
+ minimum=0,
429
+ maximum=16,
430
+ value=10,
431
+ step=1)
432
+ Style_Guidance = gr.Slider(label="Style Guidance Scale",
433
+ minimum=0,
434
+ maximum=4,
435
+ value=1.2,
436
+ step=0.05)
437
+ cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20)
438
+
439
+
440
+ with gr.Tab('FreeU'):
441
+ with gr.Row():
442
+ freeu = gr.Checkbox(label="Free Upblock", value=False)
443
+ de_bug = gr.Checkbox(value=False,label='DeBug')
444
+ inter_latents = gr.Checkbox(value=True,label='Use intermediate latents')
445
+ with gr.Row():
446
+ b1 = gr.Slider(label='b1:',
447
+ minimum=-1,
448
+ maximum=2,
449
+ step=0.01,
450
+ value=1.3)
451
+ b2 = gr.Slider(label='b2:',
452
+ minimum=-1,
453
+ maximum=2,
454
+ step=0.01,
455
+ value=1.5)
456
+ with gr.Row():
457
+ s1 = gr.Slider(label='s1: ',
458
+ minimum=0,
459
+ maximum=2,
460
+ step=0.1,
461
+ value=1.0)
462
+ s2 = gr.Slider(label='s2:',
463
+ minimum=0,
464
+ maximum=2,
465
+ step=0.1,
466
+ value=1.0)
467
+ with gr.Row():
468
+ seed_textbox = gr.Textbox(label="Seed", value=-1)
469
+ seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
470
+ seed_button.click(fn=lambda: random.randint(1, 10**8), inputs=[], outputs=[seed_textbox])
471
+
472
+ with gr.Column():
473
+ generate_button = gr.Button(value="Generate", variant='primary')
474
+
475
+ generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512,)
476
+
477
+ with gr.Row():
478
+ recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
479
+ recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)
480
+
481
+ with gr.Tab("SAM"):
482
+ with gr.Column():
483
+ add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
484
+ with gr.Row():
485
+ sam_source_btn = gr.Button(value="SAM Source")
486
+ send_source_btn = gr.Button(value="Send Source")
487
+
488
+ sam_style_btn = gr.Button(value="SAM Style")
489
+ send_style_btn = gr.Button(value="Send Style")
490
+ with gr.Row():
491
+ source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
492
+ style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
493
+
494
+ with gr.Row():
495
+ source_image_with_points = gr.Image(label="source Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
496
+ source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
497
+
498
+ style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
499
+ style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)
500
+
501
+ gr.Examples(
502
+ [[os.path.join(os.path.dirname(__file__), "gradio_app/images/content/1.jpg"),
503
+ os.path.join(os.path.dirname(__file__), "gradio_app/images/style/1.jpg")],
504
+
505
+ ],
506
+ [source_image, style_image]
507
+ )
508
+ inputs = [
509
+ source_image, style_image, source_mask, style_mask,
510
+ start_step, start_layer, Style_attn_step,
511
+ Method, Style_Guidance,ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
512
+ prompt_textbox, negative_prompt_textbox, inter_latents,
513
+ freeu, b1, b2, s1, s2,
514
+ width_slider,height_slider,
515
+ ]
516
+
517
+ generate_button.click(
518
+ fn=global_text.generate,
519
+ inputs=inputs,
520
+ outputs=[recons_style,recons_content,generate_image]
521
+ )
522
+ source_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
523
+ style_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
524
+ ddim_steps.change(fn=global_text.reset_start_code, inputs=[], outputs=[])
525
+ return demo
526
+
527
+ if __name__ == "__main__":
528
+ demo = ui()
529
+ demo.launch(server_name="172.18.32.44")
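For reference, the following is a minimal sketch of how the pieces added above fit together outside of Gradio, mirroring GlobalText.generate. The model id, the "head" prompt, and the random placeholder images are assumptions for illustration only (not part of this commit), FreeU registration is omitted, and the utils modules from this commit are assumed to be importable.

    import torch
    from diffusers import DDIMScheduler

    from utils.diffuser_utils import MasaCtrlPipeline
    from utils.masactrl_utils import AttentionBase, regiter_attention_editor_diffusers
    from utils.style_attn_control import MaskPromptedStyleAttentionControl

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model_path = "runwayml/stable-diffusion-v1-5"  # assumed model id

    scheduler = DDIMScheduler.from_pretrained(model_path, subfolder="scheduler")
    pipe = MasaCtrlPipeline.from_pretrained(model_path, scheduler=scheduler).to(device)

    prompt = "head"                     # placeholder prompt, as in the Gradio default
    prompts = [prompt, prompt, prompt]  # [style, content, output], matching generate()

    # (2, 3, H, W) tensor in [-1, 1] holding [style_image, source_image];
    # random data stands in for real images here.
    content_style = torch.randn(2, 3, 512, 512, device=device)

    with torch.no_grad():
        # 1) DDIM-invert both images with a plain attention editor registered.
        regiter_attention_editor_diffusers(pipe, AttentionBase())
        start_code, latents_list = pipe.invert(
            content_style, [prompt, prompt],
            guidance_scale=0.0, num_inference_steps=50, return_intermediates=True)
        start_code = torch.cat([start_code, start_code[1:]], dim=0)  # 3 latents for 3 prompts

        # 2) Re-sample with the style attention controller registered (UI default values).
        controller = MaskPromptedStyleAttentionControl(
            0, 10, style_attn_step=35, style_guidance=1.2,
            style_mask=None, source_mask=None, only_masked_region=False,
            guidance=0.0, de_bug=False)
        regiter_attention_editor_diffusers(pipe, controller)
        images = pipe(
            prompts, width=512, height=512, latents=start_code,
            guidance_scale=0.0, num_inference_steps=50,
            ref_intermediate_latents=latents_list)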
app.sh ADDED
@@ -0,0 +1,7 @@
+ #!/bin/bash
+
+ export CUDA_VISIBLE_DEVICES=$1
+
+ echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
+ # export CUDA_VISIBLE_DEVICES=5
+ python app.py
pdiff/pdiff_pipeline.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Util functions based on the Diffusers framework.
3
+ """
4
+
5
+
6
+ import os
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from torchvision.utils import save_image
15
+ from torchvision.io import read_image
16
+
17
+ from diffusers import StableDiffusionPipeline
18
+
19
+ from pytorch_lightning import seed_everything
20
+
21
+
22
+ class MasaCtrlPipeline(StableDiffusionPipeline):
23
+
24
+ def next_step(
25
+ self,
26
+ model_output: torch.FloatTensor,
27
+ timestep: int,
28
+ x: torch.FloatTensor,
29
+ eta=0.,
30
+ verbose=False
31
+ ):
32
+ """
33
+ Inverse sampling for DDIM Inversion
34
+ """
35
+ if verbose:
36
+ print("timestep: ", timestep)
37
+ next_step = timestep
38
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
39
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
40
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
41
+ beta_prod_t = 1 - alpha_prod_t
42
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
43
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
44
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
45
+ return x_next, pred_x0
46
+
47
+ def step(
48
+ self,
49
+ model_output: torch.FloatTensor,
50
+ timestep: int,
51
+ x: torch.FloatTensor,
52
+ eta: float=0.0,
53
+ verbose=False,
54
+ ):
55
+ """
56
+ Predict the sample at the next step in the denoising process (one DDIM step).
57
+ """
58
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
59
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
60
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
61
+ beta_prod_t = 1 - alpha_prod_t
62
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
63
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
64
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
65
+ return x_prev, pred_x0
66
+
67
+ @torch.no_grad()
68
+ def image2latent(self, image):
69
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
70
+ if isinstance(image, Image.Image):
71
+ image = np.array(image)
72
+ image = torch.from_numpy(image).float() / 127.5 - 1
73
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
74
+ # input image density range [-1, 1]
75
+ latents = self.vae.encode(image)['latent_dist'].mean
76
+ latents = latents * 0.18215
77
+ return latents
78
+
79
+ @torch.no_grad()
80
+ def latent2image(self, latents, return_type='np'):
81
+ latents = 1 / 0.18215 * latents.detach()
82
+ image = self.vae.decode(latents)['sample']
83
+ if return_type == 'np':
84
+ image = (image / 2 + 0.5).clamp(0, 1)
85
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
86
+ image = (image * 255).astype(np.uint8)
87
+ elif return_type == "pt":
88
+ image = (image / 2 + 0.5).clamp(0, 1)
89
+
90
+ return image
91
+
92
+ def latent2image_grad(self, latents):
93
+ latents = 1 / 0.18215 * latents
94
+ image = self.vae.decode(latents)['sample']
95
+
96
+ return image # range [-1, 1]
97
+
98
+ @torch.no_grad()
99
+ def __call__(
100
+ self,
101
+ prompt,
102
+ batch_size=1,
103
+ height=512,
104
+ width=512,
105
+ num_inference_steps=50,
106
+ guidance_scale=7.5,
107
+ eta=0.0,
108
+ latents=None,
109
+ unconditioning=None,
110
+ neg_prompt=None,
111
+ ref_intermediate_latents=None,
112
+ return_intermediates=False,
113
+ **kwds):
114
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
115
+ if isinstance(prompt, list):
116
+ batch_size = len(prompt)
117
+ elif isinstance(prompt, str):
118
+ if batch_size > 1:
119
+ prompt = [prompt] * batch_size
120
+
121
+ # text embeddings
122
+ text_input = self.tokenizer(
123
+ prompt,
124
+ padding="max_length",
125
+ max_length=77,
126
+ return_tensors="pt"
127
+ )
128
+
129
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
130
+ print("input text embeddings :", text_embeddings.shape)
131
+ if kwds.get("dir"):
132
+ dir = text_embeddings[-2] - text_embeddings[-1]
133
+ u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
134
+ text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
135
+ print(u.shape)
136
+ print(v.shape)
137
+
138
+ # define initial latents
139
+ latents_shape = (batch_size, self.unet.config.in_channels, height//8, width//8)
140
+ if latents is None:
141
+ latents = torch.randn(latents_shape, device=DEVICE)
142
+ else:
143
+ assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."
144
+
145
+ # unconditional embedding for classifier free guidance
146
+ if guidance_scale > 1.:
147
+ max_length = text_input.input_ids.shape[-1]
148
+ if neg_prompt:
149
+ uc_text = neg_prompt
150
+ else:
151
+ uc_text = ""
152
+ # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
153
+ unconditional_input = self.tokenizer(
154
+ [uc_text] * batch_size,
155
+ padding="max_length",
156
+ max_length=77,
157
+ return_tensors="pt"
158
+ )
159
+ # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
160
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
161
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
162
+
163
+ print("latents shape: ", latents.shape)
164
+ # iterative sampling
165
+ self.scheduler.set_timesteps(num_inference_steps)
166
+ # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
167
+ latents_list = [latents]
168
+ pred_x0_list = [latents]
169
+ for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
170
+ if ref_intermediate_latents is not None:
171
+ # note that the batch_size >= 2
172
+ latents_ref = ref_intermediate_latents[-1 - i]
173
+ _, latents_cur = latents.chunk(2)
174
+ latents = torch.cat([latents_ref, latents_cur])
175
+
176
+ if guidance_scale > 1.:
177
+ model_inputs = torch.cat([latents] * 2)
178
+ else:
179
+ model_inputs = latents
180
+ if unconditioning is not None and isinstance(unconditioning, list):
181
+ _, text_embeddings = text_embeddings.chunk(2)
182
+ text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
183
+ # predict the noise
184
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
185
+ if guidance_scale > 1.:
186
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
187
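+ # classifier-free guidance: push the prediction away from the unconditional branch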
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
188
+ # compute the previous noise sample x_t -> x_t-1
189
+ latents, pred_x0 = self.step(noise_pred, t, latents)
190
+ latents_list.append(latents)
191
+ pred_x0_list.append(pred_x0)
192
+
193
+ image = self.latent2image(latents, return_type="pt")
194
+ if return_intermediates:
195
+ pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
196
+ latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
197
+ return image, pred_x0_list, latents_list
198
+ return image
199
+
200
+ @torch.no_grad()
201
+ def invert(
202
+ self,
203
+ image: torch.Tensor,
204
+ prompt,
205
+ num_inference_steps=50,
206
+ guidance_scale=7.5,
207
+ eta=0.0,
208
+ return_intermediates=False,
209
+ **kwds):
210
+ """
211
+ Invert a real image into a noise map with deterministic DDIM inversion.
212
+ """
213
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
214
+ batch_size = image.shape[0]
215
+ if isinstance(prompt, list):
216
+ if batch_size == 1:
217
+ image = image.expand(len(prompt), -1, -1, -1)
218
+ elif isinstance(prompt, str):
219
+ if batch_size > 1:
220
+ prompt = [prompt] * batch_size
221
+
222
+ # text embeddings
223
+ text_input = self.tokenizer(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=77,
227
+ return_tensors="pt"
228
+ )
229
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
230
+ print("input text embeddings :", text_embeddings.shape)
231
+ # define initial latents
232
+ latents = self.image2latent(image)
233
+ start_latents = latents
234
+ # print(latents)
235
+ # exit()
236
+ # unconditional embedding for classifier free guidance
237
+ if guidance_scale > 1.:
238
+ max_length = text_input.input_ids.shape[-1]
239
+ unconditional_input = self.tokenizer(
240
+ [""] * batch_size,
241
+ padding="max_length",
242
+ max_length=77,
243
+ return_tensors="pt"
244
+ )
245
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
246
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
247
+
248
+ print("latents shape: ", latents.shape)
249
+ # iterative sampling
250
+ self.scheduler.set_timesteps(num_inference_steps)
251
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
252
+ # print("attributes: ", self.scheduler.__dict__)
253
+ latents_list = [latents]
254
+ pred_x0_list = [latents]
255
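+ # walk the scheduler timesteps in reverse (t -> t+1) to drive the clean latent back towards noise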
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
256
+ if guidance_scale > 1.:
257
+ model_inputs = torch.cat([latents] * 2)
258
+ else:
259
+ model_inputs = latents
260
+
261
+ # predict the noise
262
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
263
+ if guidance_scale > 1.:
264
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
265
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
266
+ # compute the previous noise sample x_t-1 -> x_t
267
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
268
+ latents_list.append(latents)
269
+ pred_x0_list.append(pred_x0)
270
+
271
+ if return_intermediates:
272
+ # return the intermediate latents during inversion
273
+ # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
274
+ return latents, latents_list
275
+ return latents, start_latents
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ diffusers==0.15.0
+ transformers
+ opencv-python
+ einops
+ omegaconf
+ pytorch_lightning
style.css ADDED
@@ -0,0 +1,3 @@
+ h1 {
+ text-align: center;
+ }
utils/__init__.py ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc ADDED (binary file, 141 Bytes)
utils/__pycache__/__init__.cpython-38.pyc ADDED (binary file, 139 Bytes)
utils/__pycache__/__init__.cpython-39.pyc ADDED (binary file, 139 Bytes)
utils/__pycache__/convert_from_ckpt.cpython-310.pyc ADDED (binary file, 27.2 kB)
utils/__pycache__/convert_from_ckpt.cpython-38.pyc ADDED (binary file, 28.2 kB)
utils/__pycache__/convert_from_ckpt.cpython-39.pyc ADDED (binary file, 27.9 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc ADDED (binary file, 3.36 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc ADDED (binary file, 3.36 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-39.pyc ADDED (binary file, 3.33 kB)
utils/__pycache__/diffuser_utils.cpython-310.pyc ADDED (binary file, 6.77 kB)
utils/__pycache__/diffuser_utils.cpython-38.pyc ADDED (binary file, 6.8 kB)
utils/__pycache__/diffuser_utils.cpython-39.pyc ADDED (binary file, 6.78 kB)
utils/__pycache__/free_lunch_utils.cpython-310.pyc ADDED (binary file, 8.28 kB)
utils/__pycache__/free_lunch_utils.cpython-38.pyc ADDED (binary file, 8.6 kB)
utils/__pycache__/free_lunch_utils.cpython-39.pyc ADDED (binary file, 8.59 kB)
utils/__pycache__/masactrl_utils.cpython-310.pyc ADDED (binary file, 6.2 kB)
utils/__pycache__/masactrl_utils.cpython-38.pyc ADDED (binary file, 6.71 kB)
utils/__pycache__/masactrl_utils.cpython-39.pyc ADDED (binary file, 6.69 kB)
utils/__pycache__/style_attn_control.cpython-310.pyc ADDED (binary file, 8.73 kB)
utils/__pycache__/style_attn_control.cpython-38.pyc ADDED (binary file, 8.67 kB)
utils/__pycache__/style_attn_control.cpython-39.pyc ADDED (binary file, 8.75 kB)
utils/convert_from_ckpt.py ADDED
@@ -0,0 +1,959 @@
1
+ # coding=utf-8
2
+ # Copyright 2023 The HuggingFace Inc. team.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+ """ Conversion script for the Stable Diffusion checkpoints."""
16
+
17
+ import re
18
+ from io import BytesIO
19
+ from typing import Optional
20
+
21
+ import requests
22
+ import torch
23
+ from transformers import (
24
+ AutoFeatureExtractor,
25
+ BertTokenizerFast,
26
+ CLIPImageProcessor,
27
+ CLIPTextModel,
28
+ CLIPTextModelWithProjection,
29
+ CLIPTokenizer,
30
+ CLIPVisionConfig,
31
+ CLIPVisionModelWithProjection,
32
+ )
33
+
34
+ from diffusers.models import (
35
+ AutoencoderKL,
36
+ PriorTransformer,
37
+ UNet2DConditionModel,
38
+ )
39
+ from diffusers.schedulers import (
40
+ DDIMScheduler,
41
+ DDPMScheduler,
42
+ DPMSolverMultistepScheduler,
43
+ EulerAncestralDiscreteScheduler,
44
+ EulerDiscreteScheduler,
45
+ HeunDiscreteScheduler,
46
+ LMSDiscreteScheduler,
47
+ PNDMScheduler,
48
+ UnCLIPScheduler,
49
+ )
50
+ from diffusers.utils.import_utils import BACKENDS_MAPPING
51
+
52
+
53
+ def shave_segments(path, n_shave_prefix_segments=1):
54
+ """
55
+ Removes segments. Positive values shave the first segments, negative shave the last segments.
56
+ """
57
+ if n_shave_prefix_segments >= 0:
58
+ return ".".join(path.split(".")[n_shave_prefix_segments:])
59
+ else:
60
+ return ".".join(path.split(".")[:n_shave_prefix_segments])
61
+
62
+
63
+ def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
64
+ """
65
+ Updates paths inside resnets to the new naming scheme (local renaming)
66
+ """
67
+ mapping = []
68
+ for old_item in old_list:
69
+ new_item = old_item.replace("in_layers.0", "norm1")
70
+ new_item = new_item.replace("in_layers.2", "conv1")
71
+
72
+ new_item = new_item.replace("out_layers.0", "norm2")
73
+ new_item = new_item.replace("out_layers.3", "conv2")
74
+
75
+ new_item = new_item.replace("emb_layers.1", "time_emb_proj")
76
+ new_item = new_item.replace("skip_connection", "conv_shortcut")
77
+
78
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
79
+
80
+ mapping.append({"old": old_item, "new": new_item})
81
+
82
+ return mapping
83
+
84
+
85
+ def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
86
+ """
87
+ Updates paths inside resnets to the new naming scheme (local renaming)
88
+ """
89
+ mapping = []
90
+ for old_item in old_list:
91
+ new_item = old_item
92
+
93
+ new_item = new_item.replace("nin_shortcut", "conv_shortcut")
94
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
95
+
96
+ mapping.append({"old": old_item, "new": new_item})
97
+
98
+ return mapping
99
+
100
+
101
+ def renew_attention_paths(old_list, n_shave_prefix_segments=0):
102
+ """
103
+ Updates paths inside attentions to the new naming scheme (local renaming)
104
+ """
105
+ mapping = []
106
+ for old_item in old_list:
107
+ new_item = old_item
108
+
109
+ # new_item = new_item.replace('norm.weight', 'group_norm.weight')
110
+ # new_item = new_item.replace('norm.bias', 'group_norm.bias')
111
+
112
+ # new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
113
+ # new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
114
+
115
+ # new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
116
+
117
+ mapping.append({"old": old_item, "new": new_item})
118
+
119
+ return mapping
120
+
121
+
122
+ def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
123
+ """
124
+ Updates paths inside attentions to the new naming scheme (local renaming)
125
+ """
126
+ mapping = []
127
+ for old_item in old_list:
128
+ new_item = old_item
129
+
130
+ new_item = new_item.replace("norm.weight", "group_norm.weight")
131
+ new_item = new_item.replace("norm.bias", "group_norm.bias")
132
+
133
+ new_item = new_item.replace("q.weight", "query.weight")
134
+ new_item = new_item.replace("q.bias", "query.bias")
135
+
136
+ new_item = new_item.replace("k.weight", "key.weight")
137
+ new_item = new_item.replace("k.bias", "key.bias")
138
+
139
+ new_item = new_item.replace("v.weight", "value.weight")
140
+ new_item = new_item.replace("v.bias", "value.bias")
141
+
142
+ new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
143
+ new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
144
+
145
+ new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
146
+
147
+ mapping.append({"old": old_item, "new": new_item})
148
+
149
+ return mapping
150
+
151
+
152
+ def assign_to_checkpoint(
153
+ paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
154
+ ):
155
+ """
156
+ This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
157
+ attention layers, and takes into account additional replacements that may arise.
158
+
159
+ Assigns the weights to the new checkpoint.
160
+ """
161
+ assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
162
+
163
+ # Splits the attention layers into three variables.
164
+ if attention_paths_to_split is not None:
165
+ for path, path_map in attention_paths_to_split.items():
166
+ old_tensor = old_checkpoint[path]
167
+ channels = old_tensor.shape[0] // 3
168
+
169
+ target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
170
+
171
+ num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
172
+
173
+ old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
174
+ query, key, value = old_tensor.split(channels // num_heads, dim=1)
175
+
176
+ checkpoint[path_map["query"]] = query.reshape(target_shape)
177
+ checkpoint[path_map["key"]] = key.reshape(target_shape)
178
+ checkpoint[path_map["value"]] = value.reshape(target_shape)
179
+
180
+ for path in paths:
181
+ new_path = path["new"]
182
+
183
+ # These have already been assigned
184
+ if attention_paths_to_split is not None and new_path in attention_paths_to_split:
185
+ continue
186
+
187
+ # Global renaming happens here
188
+ new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
189
+ new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
190
+ new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
191
+
192
+ if additional_replacements is not None:
193
+ for replacement in additional_replacements:
194
+ new_path = new_path.replace(replacement["old"], replacement["new"])
195
+
196
+ # proj_attn.weight has to be converted from conv 1D to linear
197
+ if "proj_attn.weight" in new_path:
198
+ checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
199
+ else:
200
+ checkpoint[new_path] = old_checkpoint[path["old"]]
201
+
202
+
203
+ def conv_attn_to_linear(checkpoint):
204
+ keys = list(checkpoint.keys())
205
+ attn_keys = ["query.weight", "key.weight", "value.weight"]
206
+ for key in keys:
207
+ if ".".join(key.split(".")[-2:]) in attn_keys:
208
+ if checkpoint[key].ndim > 2:
209
+ checkpoint[key] = checkpoint[key][:, :, 0, 0]
210
+ elif "proj_attn.weight" in key:
211
+ if checkpoint[key].ndim > 2:
212
+ checkpoint[key] = checkpoint[key][:, :, 0]
213
+
214
+
215
+ def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
216
+ """
217
+ Creates a config for the diffusers based on the config of the LDM model.
218
+ """
219
+ if controlnet:
220
+ unet_params = original_config.model.params.control_stage_config.params
221
+ else:
222
+ unet_params = original_config.model.params.unet_config.params
223
+
224
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
225
+
226
+ block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]
227
+
228
+ down_block_types = []
229
+ resolution = 1
230
+ for i in range(len(block_out_channels)):
231
+ block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
232
+ down_block_types.append(block_type)
233
+ if i != len(block_out_channels) - 1:
234
+ resolution *= 2
235
+
236
+ up_block_types = []
237
+ for i in range(len(block_out_channels)):
238
+ block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
239
+ up_block_types.append(block_type)
240
+ resolution //= 2
241
+
242
+ vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)
243
+
244
+ head_dim = unet_params.num_heads if "num_heads" in unet_params else None
245
+ use_linear_projection = (
246
+ unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
247
+ )
248
+ if use_linear_projection:
249
+ # stable diffusion 2-base-512 and 2-768
250
+ if head_dim is None:
251
+ head_dim = [5, 10, 20, 20]
252
+
253
+ class_embed_type = None
254
+ projection_class_embeddings_input_dim = None
255
+
256
+ if "num_classes" in unet_params:
257
+ if unet_params.num_classes == "sequential":
258
+ class_embed_type = "projection"
259
+ assert "adm_in_channels" in unet_params
260
+ projection_class_embeddings_input_dim = unet_params.adm_in_channels
261
+ else:
262
+ raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")
263
+
264
+ config = {
265
+ "sample_size": image_size // vae_scale_factor,
266
+ "in_channels": unet_params.in_channels,
267
+ "down_block_types": tuple(down_block_types),
268
+ "block_out_channels": tuple(block_out_channels),
269
+ "layers_per_block": unet_params.num_res_blocks,
270
+ "cross_attention_dim": unet_params.context_dim,
271
+ "attention_head_dim": head_dim,
272
+ "use_linear_projection": use_linear_projection,
273
+ "class_embed_type": class_embed_type,
274
+ "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
275
+ }
276
+
277
+ if not controlnet:
278
+ config["out_channels"] = unet_params.out_channels
279
+ config["up_block_types"] = tuple(up_block_types)
280
+
281
+ return config
282
+
283
+
284
+ def create_vae_diffusers_config(original_config, image_size: int):
285
+ """
286
+ Creates a config for the diffusers based on the config of the LDM model.
287
+ """
288
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
289
+ _ = original_config.model.params.first_stage_config.params.embed_dim
290
+
291
+ block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
292
+ down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
293
+ up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
294
+
295
+ config = {
296
+ "sample_size": image_size,
297
+ "in_channels": vae_params.in_channels,
298
+ "out_channels": vae_params.out_ch,
299
+ "down_block_types": tuple(down_block_types),
300
+ "up_block_types": tuple(up_block_types),
301
+ "block_out_channels": tuple(block_out_channels),
302
+ "latent_channels": vae_params.z_channels,
303
+ "layers_per_block": vae_params.num_res_blocks,
304
+ }
305
+ return config
306
+
307
+
308
+ def create_diffusers_schedular(original_config):
309
+ schedular = DDIMScheduler(
310
+ num_train_timesteps=original_config.model.params.timesteps,
311
+ beta_start=original_config.model.params.linear_start,
312
+ beta_end=original_config.model.params.linear_end,
313
+ beta_schedule="scaled_linear",
314
+ )
315
+ return schedular
316
+
317
+
318
+ def create_ldm_bert_config(original_config):
319
+ bert_params = original_config.model.params.cond_stage_config.params
320
+ config = LDMBertConfig(
321
+ d_model=bert_params.n_embed,
322
+ encoder_layers=bert_params.n_layer,
323
+ encoder_ffn_dim=bert_params.n_embed * 4,
324
+ )
325
+ return config
326
+
327
+
328
+ def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
329
+ """
330
+ Takes a state dict and a config, and returns a converted checkpoint.
331
+ """
332
+
333
+ # extract state_dict for UNet
334
+ unet_state_dict = {}
335
+ keys = list(checkpoint.keys())
336
+
337
+ if controlnet:
338
+ unet_key = "control_model."
339
+ else:
340
+ unet_key = "model.diffusion_model."
341
+
342
+ # at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
343
+ if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
344
+ print(f"Checkpoint {path} has both EMA and non-EMA weights.")
345
+ print(
346
+ "In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
347
+ " weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
348
+ )
349
+ for key in keys:
350
+ if key.startswith("model.diffusion_model"):
351
+ flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
352
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
353
+ else:
354
+ if sum(k.startswith("model_ema") for k in keys) > 100:
355
+ print(
356
+ "In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
357
+ " weights (usually better for inference), please make sure to add the `--extract_ema` flag."
358
+ )
359
+
360
+ for key in keys:
361
+ if key.startswith(unet_key):
362
+ unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
363
+
364
+ new_checkpoint = {}
365
+
366
+ new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
367
+ new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
368
+ new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
369
+ new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
370
+
371
+ if config["class_embed_type"] is None:
372
+ # No parameters to port
373
+ ...
374
+ elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
375
+ new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
376
+ new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
377
+ new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
378
+ new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
379
+ else:
380
+ raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
381
+
382
+ new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
383
+ new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
384
+
385
+ if not controlnet:
386
+ new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
387
+ new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
388
+ new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
389
+ new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
390
+
391
+ # Retrieves the keys for the input blocks only
392
+ num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
393
+ input_blocks = {
394
+ layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
395
+ for layer_id in range(num_input_blocks)
396
+ }
397
+
398
+ # Retrieves the keys for the middle blocks only
399
+ num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
400
+ middle_blocks = {
401
+ layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
402
+ for layer_id in range(num_middle_blocks)
403
+ }
404
+
405
+ # Retrieves the keys for the output blocks only
406
+ num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
407
+ output_blocks = {
408
+ layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
409
+ for layer_id in range(num_output_blocks)
410
+ }
411
+
412
+ for i in range(1, num_input_blocks):
413
+ block_id = (i - 1) // (config["layers_per_block"] + 1)
414
+ layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
415
+
416
+ resnets = [
417
+ key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
418
+ ]
419
+ attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
420
+
421
+ if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
422
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
423
+ f"input_blocks.{i}.0.op.weight"
424
+ )
425
+ new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
426
+ f"input_blocks.{i}.0.op.bias"
427
+ )
428
+
429
+ paths = renew_resnet_paths(resnets)
430
+ meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
431
+ assign_to_checkpoint(
432
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
433
+ )
434
+
435
+ if len(attentions):
436
+ paths = renew_attention_paths(attentions)
437
+ meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
438
+ assign_to_checkpoint(
439
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
440
+ )
441
+
442
+ resnet_0 = middle_blocks[0]
443
+ attentions = middle_blocks[1]
444
+ resnet_1 = middle_blocks[2]
445
+
446
+ resnet_0_paths = renew_resnet_paths(resnet_0)
447
+ assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
448
+
449
+ resnet_1_paths = renew_resnet_paths(resnet_1)
450
+ assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
451
+
452
+ attentions_paths = renew_attention_paths(attentions)
453
+ meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
454
+ assign_to_checkpoint(
455
+ attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
456
+ )
457
+
458
+ for i in range(num_output_blocks):
459
+ block_id = i // (config["layers_per_block"] + 1)
460
+ layer_in_block_id = i % (config["layers_per_block"] + 1)
461
+ output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
462
+ output_block_list = {}
463
+
464
+ for layer in output_block_layers:
465
+ layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
466
+ if layer_id in output_block_list:
467
+ output_block_list[layer_id].append(layer_name)
468
+ else:
469
+ output_block_list[layer_id] = [layer_name]
470
+
471
+ if len(output_block_list) > 1:
472
+ resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
473
+ attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
474
+
475
+ resnet_0_paths = renew_resnet_paths(resnets)
476
+ paths = renew_resnet_paths(resnets)
477
+
478
+ meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
479
+ assign_to_checkpoint(
480
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
481
+ )
482
+
483
+ output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
484
+ if ["conv.bias", "conv.weight"] in output_block_list.values():
485
+ index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
486
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
487
+ f"output_blocks.{i}.{index}.conv.weight"
488
+ ]
489
+ new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
490
+ f"output_blocks.{i}.{index}.conv.bias"
491
+ ]
492
+
493
+ # Clear attentions as they have been attributed above.
494
+ if len(attentions) == 2:
495
+ attentions = []
496
+
497
+ if len(attentions):
498
+ paths = renew_attention_paths(attentions)
499
+ meta_path = {
500
+ "old": f"output_blocks.{i}.1",
501
+ "new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
502
+ }
503
+ assign_to_checkpoint(
504
+ paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
505
+ )
506
+ else:
507
+ resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
508
+ for path in resnet_0_paths:
509
+ old_path = ".".join(["output_blocks", str(i), path["old"]])
510
+ new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
511
+
512
+ new_checkpoint[new_path] = unet_state_dict[old_path]
513
+
514
+ if controlnet:
515
+ # conditioning embedding
516
+
517
+ orig_index = 0
518
+
519
+ new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
520
+ f"input_hint_block.{orig_index}.weight"
521
+ )
522
+ new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
523
+ f"input_hint_block.{orig_index}.bias"
524
+ )
525
+
526
+ orig_index += 2
527
+
528
+ diffusers_index = 0
529
+
530
+ while diffusers_index < 6:
531
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
532
+ f"input_hint_block.{orig_index}.weight"
533
+ )
534
+ new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
535
+ f"input_hint_block.{orig_index}.bias"
536
+ )
537
+ diffusers_index += 1
538
+ orig_index += 2
539
+
540
+ new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
541
+ f"input_hint_block.{orig_index}.weight"
542
+ )
543
+ new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
544
+ f"input_hint_block.{orig_index}.bias"
545
+ )
546
+
547
+ # down blocks
548
+ for i in range(num_input_blocks):
549
+ new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
550
+ new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
551
+
552
+ # mid block
553
+ new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
554
+ new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
555
+
556
+ return new_checkpoint
557
+
558
+
559
+ def convert_ldm_vae_checkpoint(checkpoint, config):
560
+ # extract state dict for VAE
561
+ vae_state_dict = {}
562
+ vae_key = "first_stage_model."
563
+ keys = list(checkpoint.keys())
564
+ for key in keys:
565
+ if key.startswith(vae_key):
566
+ vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
567
+
568
+ new_checkpoint = {}
569
+
570
+ new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
571
+ new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
572
+ new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
573
+ new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
574
+ new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
575
+ new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
576
+
577
+ new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
578
+ new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
579
+ new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
580
+ new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
581
+ new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
582
+ new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
583
+
584
+ new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
585
+ new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
586
+ new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
587
+ new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
588
+
589
+ # Retrieves the keys for the encoder down blocks only
590
+ num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
591
+ down_blocks = {
592
+ layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
593
+ }
594
+
595
+ # Retrieves the keys for the decoder up blocks only
596
+ num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
597
+ up_blocks = {
598
+ layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
599
+ }
600
+
601
+ for i in range(num_down_blocks):
602
+ resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
603
+
604
+ if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
605
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
606
+ f"encoder.down.{i}.downsample.conv.weight"
607
+ )
608
+ new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
609
+ f"encoder.down.{i}.downsample.conv.bias"
610
+ )
611
+
612
+ paths = renew_vae_resnet_paths(resnets)
613
+ meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
614
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
615
+
616
+ mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
617
+ num_mid_res_blocks = 2
618
+ for i in range(1, num_mid_res_blocks + 1):
619
+ resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
620
+
621
+ paths = renew_vae_resnet_paths(resnets)
622
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
623
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
624
+
625
+ mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
626
+ paths = renew_vae_attention_paths(mid_attentions)
627
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
628
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
629
+ conv_attn_to_linear(new_checkpoint)
630
+
631
+ for i in range(num_up_blocks):
632
+ block_id = num_up_blocks - 1 - i
633
+ resnets = [
634
+ key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
635
+ ]
636
+
637
+ if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
638
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
639
+ f"decoder.up.{block_id}.upsample.conv.weight"
640
+ ]
641
+ new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
642
+ f"decoder.up.{block_id}.upsample.conv.bias"
643
+ ]
644
+
645
+ paths = renew_vae_resnet_paths(resnets)
646
+ meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
647
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
648
+
649
+ mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
650
+ num_mid_res_blocks = 2
651
+ for i in range(1, num_mid_res_blocks + 1):
652
+ resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
653
+
654
+ paths = renew_vae_resnet_paths(resnets)
655
+ meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
656
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
657
+
658
+ mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
659
+ paths = renew_vae_attention_paths(mid_attentions)
660
+ meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
661
+ assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
662
+ conv_attn_to_linear(new_checkpoint)
663
+ return new_checkpoint
664
+
665
+
666
+ def convert_ldm_bert_checkpoint(checkpoint, config):
667
+ def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
668
+ hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
669
+ hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
670
+ hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
671
+
672
+ hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
673
+ hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
674
+
675
+ def _copy_linear(hf_linear, pt_linear):
676
+ hf_linear.weight = pt_linear.weight
677
+ hf_linear.bias = pt_linear.bias
678
+
679
+ def _copy_layer(hf_layer, pt_layer):
680
+ # copy layer norms
681
+ _copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
682
+ _copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
683
+
684
+ # copy attn
685
+ _copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
686
+
687
+ # copy MLP
688
+ pt_mlp = pt_layer[1][1]
689
+ _copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
690
+ _copy_linear(hf_layer.fc2, pt_mlp.net[2])
691
+
692
+ def _copy_layers(hf_layers, pt_layers):
693
+ for i, hf_layer in enumerate(hf_layers):
694
+ if i != 0:
695
+ i += i
696
+ pt_layer = pt_layers[i : i + 2]
697
+ _copy_layer(hf_layer, pt_layer)
698
+
699
+ hf_model = LDMBertModel(config).eval()
700
+
701
+ # copy embeds
702
+ hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
703
+ hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
704
+
705
+ # copy layer norm
706
+ _copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
707
+
708
+ # copy hidden layers
709
+ _copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
710
+
711
+ _copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
712
+
713
+ return hf_model
714
+
715
+
716
+ def convert_ldm_clip_checkpoint(checkpoint):
717
+ text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
718
+ keys = list(checkpoint.keys())
719
+
720
+ text_model_dict = {}
721
+
722
+ for key in keys:
723
+ if key.startswith("cond_stage_model.transformer"):
724
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
725
+
726
+ text_model.load_state_dict(text_model_dict)
727
+
728
+ return text_model
729
+
730
+
731
+ textenc_conversion_lst = [
732
+ ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
733
+ ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
734
+ ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
735
+ ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
736
+ ]
737
+ textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}
738
+
739
+ textenc_transformer_conversion_lst = [
740
+ # (stable-diffusion, HF Diffusers)
741
+ ("resblocks.", "text_model.encoder.layers."),
742
+ ("ln_1", "layer_norm1"),
743
+ ("ln_2", "layer_norm2"),
744
+ (".c_fc.", ".fc1."),
745
+ (".c_proj.", ".fc2."),
746
+ (".attn", ".self_attn"),
747
+ ("ln_final.", "transformer.text_model.final_layer_norm."),
748
+ ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
749
+ ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
750
+ ]
751
+ protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
752
+ textenc_pattern = re.compile("|".join(protected.keys()))
753
+
754
+
755
+ def convert_paint_by_example_checkpoint(checkpoint):
756
+ config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
757
+ model = PaintByExampleImageEncoder(config)
758
+
759
+ keys = list(checkpoint.keys())
760
+
761
+ text_model_dict = {}
762
+
763
+ for key in keys:
764
+ if key.startswith("cond_stage_model.transformer"):
765
+ text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
766
+
767
+ # load clip vision
768
+ model.model.load_state_dict(text_model_dict)
769
+
770
+ # load mapper
771
+ keys_mapper = {
772
+ k[len("cond_stage_model.mapper.res") :]: v
773
+ for k, v in checkpoint.items()
774
+ if k.startswith("cond_stage_model.mapper")
775
+ }
776
+
777
+ MAPPING = {
778
+ "attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
779
+ "attn.c_proj": ["attn1.to_out.0"],
780
+ "ln_1": ["norm1"],
781
+ "ln_2": ["norm3"],
782
+ "mlp.c_fc": ["ff.net.0.proj"],
783
+ "mlp.c_proj": ["ff.net.2"],
784
+ }
785
+
786
+ mapped_weights = {}
787
+ for key, value in keys_mapper.items():
788
+ prefix = key[: len("blocks.i")]
789
+ suffix = key.split(prefix)[-1].split(".")[-1]
790
+ name = key.split(prefix)[-1].split(suffix)[0][1:-1]
791
+ mapped_names = MAPPING[name]
792
+
793
+ num_splits = len(mapped_names)
794
+ for i, mapped_name in enumerate(mapped_names):
795
+ new_name = ".".join([prefix, mapped_name, suffix])
796
+ shape = value.shape[0] // num_splits
797
+ mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
798
+
799
+ model.mapper.load_state_dict(mapped_weights)
800
+
801
+ # load final layer norm
802
+ model.final_layer_norm.load_state_dict(
803
+ {
804
+ "bias": checkpoint["cond_stage_model.final_ln.bias"],
805
+ "weight": checkpoint["cond_stage_model.final_ln.weight"],
806
+ }
807
+ )
808
+
809
+ # load final proj
810
+ model.proj_out.load_state_dict(
811
+ {
812
+ "bias": checkpoint["proj_out.bias"],
813
+ "weight": checkpoint["proj_out.weight"],
814
+ }
815
+ )
816
+
817
+ # load uncond vector
818
+ model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
819
+ return model
820
+
821
+
822
+ def convert_open_clip_checkpoint(checkpoint):
823
+ text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
824
+
825
+ keys = list(checkpoint.keys())
826
+
827
+ text_model_dict = {}
828
+
829
+ if "cond_stage_model.model.text_projection" in checkpoint:
830
+ d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
831
+ else:
832
+ d_model = 1024
833
+
834
+ text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
835
+
836
+ for key in keys:
837
+ if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
838
+ continue
839
+ if key in textenc_conversion_map:
840
+ text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
841
+ if key.startswith("cond_stage_model.model.transformer."):
842
+ new_key = key[len("cond_stage_model.model.transformer.") :]
843
+ if new_key.endswith(".in_proj_weight"):
844
+ new_key = new_key[: -len(".in_proj_weight")]
845
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
846
+ text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
847
+ text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
848
+ text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
849
+ elif new_key.endswith(".in_proj_bias"):
850
+ new_key = new_key[: -len(".in_proj_bias")]
851
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
852
+ text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
853
+ text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
854
+ text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
855
+ else:
856
+ new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
857
+
858
+ text_model_dict[new_key] = checkpoint[key]
859
+
860
+ text_model.load_state_dict(text_model_dict)
861
+
862
+ return text_model
863
+
864
+
865
+ def stable_unclip_image_encoder(original_config):
866
+ """
867
+ Returns the image processor and clip image encoder for the img2img unclip pipeline.
868
+
869
+ We currently know of two types of stable unclip models which separately use the clip and the openclip image
870
+ encoders.
871
+ """
872
+
873
+ image_embedder_config = original_config.model.params.embedder_config
874
+
875
+ sd_clip_image_embedder_class = image_embedder_config.target
876
+ sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
877
+
878
+ if sd_clip_image_embedder_class == "ClipImageEmbedder":
879
+ clip_model_name = image_embedder_config.params.model
880
+
881
+ if clip_model_name == "ViT-L/14":
882
+ feature_extractor = CLIPImageProcessor()
883
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
884
+ else:
885
+ raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
886
+
887
+ elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
888
+ feature_extractor = CLIPImageProcessor()
889
+ image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
890
+ else:
891
+ raise NotImplementedError(
892
+ f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
893
+ )
894
+
895
+ return feature_extractor, image_encoder
896
+
897
+
898
+ def stable_unclip_image_noising_components(
899
+ original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
900
+ ):
901
+ """
902
+ Returns the noising components for the img2img and txt2img unclip pipelines.
903
+
904
+ Converts the stability noise augmentor into
905
+ 1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
906
+ 2. a `DDPMScheduler` for holding the noise schedule
907
+
908
+ If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
909
+ """
910
+ noise_aug_config = original_config.model.params.noise_aug_config
911
+ noise_aug_class = noise_aug_config.target
912
+ noise_aug_class = noise_aug_class.split(".")[-1]
913
+
914
+ if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
915
+ noise_aug_config = noise_aug_config.params
916
+ embedding_dim = noise_aug_config.timestep_dim
917
+ max_noise_level = noise_aug_config.noise_schedule_config.timesteps
918
+ beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
919
+
920
+ image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
921
+ image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
922
+
923
+ if "clip_stats_path" in noise_aug_config:
924
+ if clip_stats_path is None:
925
+ raise ValueError("This stable unclip config requires a `clip_stats_path`")
926
+
927
+ clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
928
+ clip_mean = clip_mean[None, :]
929
+ clip_std = clip_std[None, :]
930
+
931
+ clip_stats_state_dict = {
932
+ "mean": clip_mean,
933
+ "std": clip_std,
934
+ }
935
+
936
+ image_normalizer.load_state_dict(clip_stats_state_dict)
937
+ else:
938
+ raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
939
+
940
+ return image_normalizer, image_noising_scheduler
941
+
942
+
943
+ def convert_controlnet_checkpoint(
944
+ checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
945
+ ):
946
+ ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
947
+ ctrlnet_config["upcast_attention"] = upcast_attention
948
+
949
+ ctrlnet_config.pop("sample_size")
950
+
951
+ controlnet_model = ControlNetModel(**ctrlnet_config)
952
+
953
+ converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
954
+ checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
955
+ )
956
+
957
+ controlnet_model.load_state_dict(converted_ctrl_checkpoint)
958
+
959
+ return controlnet_model
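A minimal usage sketch for the converters above, assuming the `create_unet_diffusers_config` / `create_vae_diffusers_config` builders defined earlier in this file; the checkpoint and YAML paths are hypothetical placeholders:

    import torch
    from omegaconf import OmegaConf
    from diffusers import AutoencoderKL, UNet2DConditionModel

    # hypothetical paths: any SD 1.x LDM checkpoint plus its original config
    checkpoint = torch.load("models/v1-5.ckpt", map_location="cpu")["state_dict"]
    original_config = OmegaConf.load("configs/v1-inference.yaml")

    # UNet: build the diffusers config, remap the LDM keys, then load
    unet_config = create_unet_diffusers_config(original_config, image_size=512)
    unet = UNet2DConditionModel(**unet_config)
    unet.load_state_dict(convert_ldm_unet_checkpoint(checkpoint, unet_config, extract_ema=False))

    # VAE and CLIP text encoder follow the same pattern
    vae_config = create_vae_diffusers_config(original_config, image_size=512)
    vae = AutoencoderKL(**vae_config)
    vae.load_state_dict(convert_ldm_vae_checkpoint(checkpoint, vae_config))
    text_encoder = convert_ldm_clip_checkpoint(checkpoint)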
utils/convert_lora_safetensor_to_diffusers.py ADDED
@@ -0,0 +1,154 @@
1
+ # coding=utf-8
2
+ # Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """ Conversion script for the LoRA's safetensors checkpoints. """
17
+
18
+ import argparse
19
+
20
+ import torch
21
+ from safetensors.torch import load_file
22
+
23
+ from diffusers import StableDiffusionPipeline
24
+ import pdb
25
+
26
+
27
+
28
+ def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
29
+ # directly update weight in diffusers model
30
+ for key in state_dict:
31
+ # only process lora down key
32
+ if "up." in key: continue
33
+
34
+ up_key = key.replace(".down.", ".up.")
35
+ model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
36
+ model_key = model_key.replace("to_out.", "to_out.0.")
37
+ layer_infos = model_key.split(".")[:-1]
38
+
39
+ curr_layer = pipeline.unet
40
+ while len(layer_infos) > 0:
41
+ temp_name = layer_infos.pop(0)
42
+ curr_layer = curr_layer.__getattr__(temp_name)
43
+
44
+ weight_down = state_dict[key]
45
+ weight_up = state_dict[up_key]
46
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
47
+
48
+ return pipeline
49
+
50
+
51
+
52
+ def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
53
+ # load base model
54
+ # pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
55
+
56
+ # load LoRA weight from .safetensors
57
+ # state_dict = load_file(checkpoint_path)
58
+
59
+ visited = []
60
+
61
+ # directly update weight in diffusers model
62
+ for key in state_dict:
63
+ # it is suggested to print the key; it will usually look like the example below
64
+ # "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
65
+
66
+ # the alpha has already been set beforehand, so just skip these keys
67
+ if ".alpha" in key or key in visited:
68
+ continue
69
+
70
+ if "text" in key:
71
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
72
+ curr_layer = pipeline.text_encoder
73
+ else:
74
+ layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
75
+ curr_layer = pipeline.unet
76
+
77
+ # find the target layer
78
+ temp_name = layer_infos.pop(0)
79
+ while len(layer_infos) > -1:
80
+ try:
81
+ curr_layer = curr_layer.__getattr__(temp_name)
82
+ if len(layer_infos) > 0:
83
+ temp_name = layer_infos.pop(0)
84
+ elif len(layer_infos) == 0:
85
+ break
86
+ except Exception:
87
+ if len(temp_name) > 0:
88
+ temp_name += "_" + layer_infos.pop(0)
89
+ else:
90
+ temp_name = layer_infos.pop(0)
91
+
92
+ pair_keys = []
93
+ if "lora_down" in key:
94
+ pair_keys.append(key.replace("lora_down", "lora_up"))
95
+ pair_keys.append(key)
96
+ else:
97
+ pair_keys.append(key)
98
+ pair_keys.append(key.replace("lora_up", "lora_down"))
99
+
100
+ # update weight
101
+ if len(state_dict[pair_keys[0]].shape) == 4:
102
+ weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
103
+ weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
104
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
105
+ else:
106
+ weight_up = state_dict[pair_keys[0]].to(torch.float32)
107
+ weight_down = state_dict[pair_keys[1]].to(torch.float32)
108
+ curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
109
+
110
+ # update visited list
111
+ for item in pair_keys:
112
+ visited.append(item)
113
+
114
+ return pipeline
115
+
116
+
117
+ if __name__ == "__main__":
118
+ parser = argparse.ArgumentParser()
119
+
120
+ parser.add_argument(
121
+ "--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
122
+ )
123
+ parser.add_argument(
124
+ "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
125
+ )
126
+ parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
127
+ parser.add_argument(
128
+ "--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
129
+ )
130
+ parser.add_argument(
131
+ "--lora_prefix_text_encoder",
132
+ default="lora_te",
133
+ type=str,
134
+ help="The prefix of text encoder weight in safetensors",
135
+ )
136
+ parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
137
+ parser.add_argument(
138
+ "--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
139
+ )
140
+ parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
141
+
142
+ args = parser.parse_args()
143
+
144
+ base_model_path = args.base_model_path
145
+ checkpoint_path = args.checkpoint_path
146
+ dump_path = args.dump_path
147
+ lora_prefix_unet = args.lora_prefix_unet
148
+ lora_prefix_text_encoder = args.lora_prefix_text_encoder
149
+ alpha = args.alpha
150
+
151
+ pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
+ lora_state_dict = load_file(checkpoint_path)
+ pipe = convert_lora(pipe, lora_state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)
152
+
153
+ pipe = pipe.to(args.device)
154
+ pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
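A minimal sketch of calling `convert_lora` directly instead of going through the CLI above; the model id and the safetensors path are placeholders:

    import torch
    from safetensors.torch import load_file
    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float32)
    lora_state_dict = load_file("loras/example_style.safetensors")  # placeholder path

    # merges W = W0 + alpha * (up @ down) into the UNet / text-encoder weights in place
    pipe = convert_lora(pipe, lora_state_dict, alpha=0.75)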
utils/diffuser_utils.py ADDED
@@ -0,0 +1,275 @@
1
+ """
2
+ Utility functions built on the diffusers framework.
3
+ """
4
+
5
+
6
+ import os
7
+ import torch
8
+ import cv2
9
+ import numpy as np
10
+
11
+ import torch.nn.functional as F
12
+ from tqdm import tqdm
13
+ from PIL import Image
14
+ from torchvision.utils import save_image
15
+ from torchvision.io import read_image
16
+
17
+ from diffusers import StableDiffusionPipeline
18
+
19
+ from pytorch_lightning import seed_everything
20
+
21
+
22
+ class MasaCtrlPipeline(StableDiffusionPipeline):
23
+
24
+ def next_step(
25
+ self,
26
+ model_output: torch.FloatTensor,
27
+ timestep: int,
28
+ x: torch.FloatTensor,
29
+ eta=0.,
30
+ verbose=False
31
+ ):
32
+ """
33
+ Inverse sampling for DDIM Inversion
34
+ """
35
+ if verbose:
36
+ print("timestep: ", timestep)
37
+ next_step = timestep
38
+ timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
39
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
40
+ alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
41
+ beta_prod_t = 1 - alpha_prod_t
42
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
43
+ pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
44
+ x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
45
+ return x_next, pred_x0
46
+
47
+ def step(
48
+ self,
49
+ model_output: torch.FloatTensor,
50
+ timestep: int,
51
+ x: torch.FloatTensor,
52
+ eta: float=0.0,
53
+ verbose=False,
54
+ ):
55
+ """
56
+ Predict the sample at the next step in the denoising process.
57
+ """
58
+ prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
59
+ alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
60
+ alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
61
+ beta_prod_t = 1 - alpha_prod_t
62
+ pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
63
+ pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
64
+ x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
65
+ return x_prev, pred_x0
66
+
67
+ @torch.no_grad()
68
+ def image2latent(self, image):
69
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
70
+ if isinstance(image, Image.Image):
71
+ image = np.array(image)
72
+ image = torch.from_numpy(image).float() / 127.5 - 1
73
+ image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
74
+ # input image value range [-1, 1]
75
+ latents = self.vae.encode(image)['latent_dist'].mean
76
+ latents = latents * 0.18215
77
+ return latents
78
+
79
+ @torch.no_grad()
80
+ def latent2image(self, latents, return_type='np'):
81
+ latents = 1 / 0.18215 * latents.detach()
82
+ image = self.vae.decode(latents)['sample']
83
+ if return_type == 'np':
84
+ image = (image / 2 + 0.5).clamp(0, 1)
85
+ image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
86
+ image = (image * 255).astype(np.uint8)
87
+ elif return_type == "pt":
88
+ image = (image / 2 + 0.5).clamp(0, 1)
89
+
90
+ return image
91
+
92
+ def latent2image_grad(self, latents):
93
+ latents = 1 / 0.18215 * latents
94
+ image = self.vae.decode(latents)['sample']
95
+
96
+ return image # range [-1, 1]
97
+
98
+ @torch.no_grad()
99
+ def __call__(
100
+ self,
101
+ prompt,
102
+ batch_size=1,
103
+ height=512,
104
+ width=512,
105
+ num_inference_steps=50,
106
+ guidance_scale=7.5,
107
+ eta=0.0,
108
+ latents=None,
109
+ unconditioning=None,
110
+ neg_prompt=None,
111
+ ref_intermediate_latents=None,
112
+ return_intermediates=False,
113
+ **kwds):
114
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
115
+ if isinstance(prompt, list):
116
+ batch_size = len(prompt)
117
+ elif isinstance(prompt, str):
118
+ if batch_size > 1:
119
+ prompt = [prompt] * batch_size
120
+
121
+ # text embeddings
122
+ text_input = self.tokenizer(
123
+ prompt,
124
+ padding="max_length",
125
+ max_length=77,
126
+ return_tensors="pt"
127
+ )
128
+
129
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
130
+ print("input text embeddings :", text_embeddings.shape)
131
+ if kwds.get("dir"):
132
+ dir = text_embeddings[-2] - text_embeddings[-1]
133
+ u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
134
+ text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
135
+ print(u.shape)
136
+ print(v.shape)
137
+
138
+ # define initial latents
139
+ latents_shape = (batch_size, self.unet.config.in_channels, height//8, width//8)
140
+ if latents is None:
141
+ latents = torch.randn(latents_shape, device=DEVICE)
142
+ else:
143
+ assert latents.shape == latents_shape, f"The shape of the input latent tensor {latents.shape} should match the expected shape {latents_shape}."
144
+
145
+ # unconditional embedding for classifier free guidance
146
+ if guidance_scale > 1.:
147
+ max_length = text_input.input_ids.shape[-1]
148
+ if neg_prompt:
149
+ uc_text = neg_prompt
150
+ else:
151
+ uc_text = ""
152
+ # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
153
+ unconditional_input = self.tokenizer(
154
+ [uc_text] * batch_size,
155
+ padding="max_length",
156
+ max_length=77,
157
+ return_tensors="pt"
158
+ )
159
+ # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
160
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
161
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
162
+
163
+ print("latents shape: ", latents.shape)
164
+ # iterative sampling
165
+ self.scheduler.set_timesteps(num_inference_steps)
166
+ # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
167
+ latents_list = [latents]
168
+ pred_x0_list = [latents]
169
+ for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
170
+ if ref_intermediate_latents is not None:
171
+ # note that the batch_size >= 2
172
+ latents_ref = ref_intermediate_latents[-1 - i]
173
+ _, latents_cur = latents.chunk(2)
174
+ latents = torch.cat([latents_ref, latents_cur])
175
+
176
+ if guidance_scale > 1.:
177
+ model_inputs = torch.cat([latents] * 2)
178
+ else:
179
+ model_inputs = latents
180
+ if unconditioning is not None and isinstance(unconditioning, list):
181
+ _, text_embeddings = text_embeddings.chunk(2)
182
+ text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
183
+ # predict the noise
184
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
185
+ if guidance_scale > 1.:
186
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
187
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
188
+ # compute the previous noise sample x_t -> x_t-1
189
+ latents, pred_x0 = self.step(noise_pred, t, latents)
190
+ latents_list.append(latents)
191
+ pred_x0_list.append(pred_x0)
192
+
193
+ image = self.latent2image(latents, return_type="pt")
194
+ if return_intermediates:
195
+ pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
196
+ latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
197
+ return image, pred_x0_list, latents_list
198
+ return image
199
+
200
+ @torch.no_grad()
201
+ def invert(
202
+ self,
203
+ image: torch.Tensor,
204
+ prompt,
205
+ num_inference_steps=50,
206
+ guidance_scale=7.5,
207
+ eta=0.0,
208
+ return_intermediates=False,
209
+ **kwds):
210
+ """
211
+ Invert a real image into a noise map with deterministic DDIM inversion.
212
+ """
213
+ DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
214
+ batch_size = image.shape[0]
215
+ if isinstance(prompt, list):
216
+ if batch_size == 1:
217
+ image = image.expand(len(prompt), -1, -1, -1)
218
+ elif isinstance(prompt, str):
219
+ if batch_size > 1:
220
+ prompt = [prompt] * batch_size
221
+
222
+ # text embeddings
223
+ text_input = self.tokenizer(
224
+ prompt,
225
+ padding="max_length",
226
+ max_length=77,
227
+ return_tensors="pt"
228
+ )
229
+ text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
230
+ print("input text embeddings :", text_embeddings.shape)
231
+ # define initial latents
232
+ latents = self.image2latent(image)
233
+ start_latents = latents
234
+ # print(latents)
235
+ # exit()
236
+ # unconditional embedding for classifier free guidance
237
+ if guidance_scale > 1.:
238
+ max_length = text_input.input_ids.shape[-1]
239
+ unconditional_input = self.tokenizer(
240
+ [""] * batch_size,
241
+ padding="max_length",
242
+ max_length=77,
243
+ return_tensors="pt"
244
+ )
245
+ unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
246
+ text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
247
+
248
+ print("latents shape: ", latents.shape)
249
+ # iterative sampling
250
+ self.scheduler.set_timesteps(num_inference_steps)
251
+ print("Valid timesteps: ", reversed(self.scheduler.timesteps))
252
+ # print("attributes: ", self.scheduler.__dict__)
253
+ latents_list = [latents]
254
+ pred_x0_list = [latents]
255
+ for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
256
+ if guidance_scale > 1.:
257
+ model_inputs = torch.cat([latents] * 2)
258
+ else:
259
+ model_inputs = latents
260
+
261
+ # predict the noise
262
+ noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
263
+ if guidance_scale > 1.:
264
+ noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
265
+ noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
266
+ # compute the previous noise sample x_t-1 -> x_t
267
+ latents, pred_x0 = self.next_step(noise_pred, t, latents)
268
+ latents_list.append(latents)
269
+ pred_x0_list.append(pred_x0)
270
+
271
+ if return_intermediates:
272
+ # return the intermediate latents during inversion
273
+ # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
274
+ return latents, latents_list
275
+ return latents, start_latents
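A minimal sketch of the intended round trip with `MasaCtrlPipeline`: DDIM-invert a real image into a start code, then sample from it again; the model id, prompt, and image path are placeholders, and a DDIM scheduler with the usual SD 1.x settings is assumed:

    import torch.nn.functional as F
    from torchvision.io import read_image
    from diffusers import DDIMScheduler

    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                              clip_sample=False, set_alpha_to_one=False)
    pipe = MasaCtrlPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", scheduler=scheduler).to("cuda")

    # load a real image and normalize it to a [1, 3, 512, 512] tensor in [-1, 1]
    image = read_image("examples/source.png").float() / 127.5 - 1   # placeholder path
    image = F.interpolate(image.unsqueeze(0), (512, 512)).to("cuda")

    prompt = "a photo of a cat"                                      # placeholder prompt
    start_code, inv_latents = pipe.invert(image, prompt, num_inference_steps=50, return_intermediates=True)
    result = pipe(prompt, latents=start_code, num_inference_steps=50, guidance_scale=7.5)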
utils/free_lunch_utils.py ADDED
@@ -0,0 +1,334 @@
1
+ import torch
2
+ import torch.fft as fft
3
+ from diffusers.models.unet_2d_condition import logger
4
+ from diffusers.utils import is_torch_version
5
+ from typing import Any, Dict, List, Optional, Tuple, Union
6
+ import torch.nn.functional as F
7
+
8
+ def isinstance_str(x: object, cls_name: str):
9
+ """
10
+ Checks whether x has any class *named* cls_name in its ancestry.
11
+ Doesn't require access to the class's implementation.
12
+
13
+ Useful for patching!
14
+ """
15
+
16
+ for _cls in x.__class__.__mro__:
17
+ if _cls.__name__ == cls_name:
18
+ return True
19
+
20
+ return False
21
+
22
+
23
+ def Fourier_filter(x, threshold, scale):
24
+ dtype = x.dtype
25
+ x = x.type(torch.float32)
26
+ # FFT
27
+ x_freq = fft.fftn(x, dim=(-2, -1))
28
+ x_freq = fft.fftshift(x_freq, dim=(-2, -1))
29
+
30
+ B, C, H, W = x_freq.shape
31
+ mask = torch.ones((B, C, H, W), device=x_freq.device)  # keep the mask on the same device as the input
32
+
33
+ crow, ccol = H // 2, W //2
34
+ mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
35
+ x_freq = x_freq * mask
36
+
37
+ # IFFT
38
+ x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
39
+ x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real
40
+
41
+ x_filtered = x_filtered.type(dtype)
42
+ return x_filtered
43
+
44
+
45
+ def register_upblock2d(model):
46
+ def up_forward(self):
47
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=None):
48
+ for resnet in self.resnets:
49
+ # pop res hidden states
50
+ res_hidden_states = res_hidden_states_tuple[-1]
51
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
52
+ #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
53
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
54
+
55
+ if self.training and self.gradient_checkpointing:
56
+
57
+ def create_custom_forward(module):
58
+ def custom_forward(*inputs):
59
+ return module(*inputs)
60
+
61
+ return custom_forward
62
+
63
+ if is_torch_version(">=", "1.11.0"):
64
+ hidden_states = torch.utils.checkpoint.checkpoint(
65
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
66
+ )
67
+ else:
68
+ hidden_states = torch.utils.checkpoint.checkpoint(
69
+ create_custom_forward(resnet), hidden_states, temb
70
+ )
71
+ else:
72
+ hidden_states = resnet(hidden_states, temb)
73
+
74
+ if self.upsamplers is not None:
75
+ for upsampler in self.upsamplers:
76
+ hidden_states = upsampler(hidden_states, upsample_size)
77
+
78
+ return hidden_states
79
+
80
+ return forward
81
+
82
+ for i, upsample_block in enumerate(model.unet.up_blocks):
83
+ if isinstance_str(upsample_block, "UpBlock2D"):
84
+ upsample_block.forward = up_forward(upsample_block)
85
+
86
+
87
+ def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2,source_mask=None):
88
+ def up_forward(self):
89
+ def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=None):
90
+ for resnet in self.resnets:
91
+ # pop res hidden states
92
+ res_hidden_states = res_hidden_states_tuple[-1]
93
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
94
+ #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
95
+
96
+ if self.source_mask is not None:
97
+ spatial_mask_source = F.interpolate(self.source_mask, (hidden_states.shape[2], hidden_states.shape[3]))
98
+ spatial_mask_source_b1 = spatial_mask_source * self.b1 + (1 - spatial_mask_source)
99
+ spatial_mask_source_b2 = spatial_mask_source * self.b2 + (1 - spatial_mask_source)
100
+ # --------------- FreeU code -----------------------
101
+ # Only operate on the first two stages
102
+ if hidden_states.shape[1] == 1280:
103
+ if self.source_mask is not None:
104
+ # where mask == 0, leave the hidden states unchanged
105
+ hidden_states[:,:640] = hidden_states[:,:640] * spatial_mask_source_b1
106
+
107
+ else:
108
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
109
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
110
+ if hidden_states.shape[1] == 640:
111
+
112
+ if self.source_mask is not None:
113
+ hidden_states[:,:320] = hidden_states[:,:320] * spatial_mask_source_b2
114
+ else:
115
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
116
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
117
+ # ---------------------------------------------------------
118
+
119
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
120
+
121
+ if self.training and self.gradient_checkpointing:
122
+
123
+ def create_custom_forward(module):
124
+ def custom_forward(*inputs):
125
+ return module(*inputs)
126
+
127
+ return custom_forward
128
+
129
+ if is_torch_version(">=", "1.11.0"):
130
+ hidden_states = torch.utils.checkpoint.checkpoint(
131
+ create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
132
+ )
133
+ else:
134
+ hidden_states = torch.utils.checkpoint.checkpoint(
135
+ create_custom_forward(resnet), hidden_states, temb
136
+ )
137
+ else:
138
+ hidden_states = resnet(hidden_states, temb)
139
+
140
+ if self.upsamplers is not None:
141
+ for upsampler in self.upsamplers:
142
+ hidden_states = upsampler(hidden_states, upsample_size)
143
+
144
+ return hidden_states
145
+
146
+ return forward
147
+
148
+ for i, upsample_block in enumerate(model.unet.up_blocks):
149
+ if isinstance_str(upsample_block, "UpBlock2D"):
150
+ upsample_block.forward = up_forward(upsample_block)
151
+ setattr(upsample_block, 'b1', b1)
152
+ setattr(upsample_block, 'b2', b2)
153
+ setattr(upsample_block, 's1', s1)
154
+ setattr(upsample_block, 's2', s2)
155
+ setattr(upsample_block, 'source_mask', source_mask)
156
+
157
+ def register_crossattn_upblock2d(model):
158
+ def up_forward(self):
159
+ def forward(
160
+ hidden_states: torch.FloatTensor,
161
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
162
+ temb: Optional[torch.FloatTensor] = None,
163
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
164
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
165
+ upsample_size: Optional[int] = None,
166
+ attention_mask: Optional[torch.FloatTensor] = None,
167
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
168
+ ):
169
+ for resnet, attn in zip(self.resnets, self.attentions):
170
+ # pop res hidden states
171
+ #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
172
+ res_hidden_states = res_hidden_states_tuple[-1]
173
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
174
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
175
+
176
+ if self.training and self.gradient_checkpointing:
177
+
178
+ def create_custom_forward(module, return_dict=None):
179
+ def custom_forward(*inputs):
180
+ if return_dict is not None:
181
+ return module(*inputs, return_dict=return_dict)
182
+ else:
183
+ return module(*inputs)
184
+
185
+ return custom_forward
186
+
187
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
188
+ hidden_states = torch.utils.checkpoint.checkpoint(
189
+ create_custom_forward(resnet),
190
+ hidden_states,
191
+ temb,
192
+ **ckpt_kwargs,
193
+ )
194
+ hidden_states = torch.utils.checkpoint.checkpoint(
195
+ create_custom_forward(attn, return_dict=False),
196
+ hidden_states,
197
+ encoder_hidden_states,
198
+ None, # timestep
199
+ None, # class_labels
200
+ cross_attention_kwargs,
201
+ attention_mask,
202
+ encoder_attention_mask,
203
+ **ckpt_kwargs,
204
+ )[0]
205
+ else:
206
+ hidden_states = resnet(hidden_states, temb)
207
+ hidden_states = attn(
208
+ hidden_states,
209
+ encoder_hidden_states=encoder_hidden_states,
210
+ cross_attention_kwargs=cross_attention_kwargs,
211
+ attention_mask=attention_mask,
212
+ encoder_attention_mask=encoder_attention_mask,
213
+ return_dict=False,
214
+ )[0]
215
+
216
+ if self.upsamplers is not None:
217
+ for upsampler in self.upsamplers:
218
+ hidden_states = upsampler(hidden_states, upsample_size)
219
+
220
+ return hidden_states
221
+
222
+ return forward
223
+
224
+ for i, upsample_block in enumerate(model.unet.up_blocks):
225
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
226
+ upsample_block.forward = up_forward(upsample_block)
227
+
228
+
229
+ def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2,source_mask=None):
230
+ def up_forward(self):
231
+ def forward(
232
+ hidden_states: torch.FloatTensor,
233
+ res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
234
+ temb: Optional[torch.FloatTensor] = None,
235
+ encoder_hidden_states: Optional[torch.FloatTensor] = None,
236
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
237
+ upsample_size: Optional[int] = None,
238
+ attention_mask: Optional[torch.FloatTensor] = None,
239
+ encoder_attention_mask: Optional[torch.FloatTensor] = None,
240
+ ):
241
+
242
+ if self.source_mask is not None:
243
+
244
+ spatial_mask_source = F.interpolate(self.source_mask, (hidden_states.shape[2], hidden_states.shape[3]))
245
+ spatial_mask_source_b1 = spatial_mask_source * self.b1 + (1 - spatial_mask_source)
246
+ spatial_mask_source_b2 = spatial_mask_source * self.b2 + (1 - spatial_mask_source)
247
+ # print(f"source mask is not none, {spatial_mask_source_b1.shape} with min {spatial_mask_source_b1.min()}", )
248
+
249
+ for resnet, attn in zip(self.resnets, self.attentions):
250
+ # pop res hidden states
251
+ #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
252
+ res_hidden_states = res_hidden_states_tuple[-1]
253
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
254
+
255
+ # --------------- FreeU code -----------------------
256
+ # Only operate on the first two stages
257
+ if hidden_states.shape[1] == 1280:
258
+ if self.source_mask is not None:
259
+ # where mask == 0, leave the hidden states unchanged
260
+ hidden_states[:,:640] = hidden_states[:,:640] * spatial_mask_source_b1
261
+
262
+ else:
263
+ hidden_states[:,:640] = hidden_states[:,:640] * self.b1
264
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
265
+ if hidden_states.shape[1] == 640:
266
+ if self.source_mask is not None:
267
+ hidden_states[:,:320] = hidden_states[:,:320] * spatial_mask_source_b2
268
+ else:
269
+ hidden_states[:,:320] = hidden_states[:,:320] * self.b2
270
+ res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
271
+ # ---------------------------------------------------------
272
+
273
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
274
+
275
+ if self.training and self.gradient_checkpointing:
276
+
277
+ def create_custom_forward(module, return_dict=None):
278
+ def custom_forward(*inputs):
279
+ if return_dict is not None:
280
+ return module(*inputs, return_dict=return_dict)
281
+ else:
282
+ return module(*inputs)
283
+
284
+ return custom_forward
285
+
286
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
287
+ hidden_states = torch.utils.checkpoint.checkpoint(
288
+ create_custom_forward(resnet),
289
+ hidden_states,
290
+ temb,
291
+ **ckpt_kwargs,
292
+ )
293
+ hidden_states = torch.utils.checkpoint.checkpoint(
294
+ create_custom_forward(attn, return_dict=False),
295
+ hidden_states,
296
+ encoder_hidden_states,
297
+ None, # timestep
298
+ None, # class_labels
299
+ cross_attention_kwargs,
300
+ attention_mask,
301
+ encoder_attention_mask,
302
+ **ckpt_kwargs,
303
+ )[0]
304
+ else:
305
+ hidden_states = resnet(hidden_states, temb)
306
+ # hidden_states = attn(
307
+ # hidden_states,
308
+ # encoder_hidden_states=encoder_hidden_states,
309
+ # cross_attention_kwargs=cross_attention_kwargs,
310
+ # encoder_attention_mask=encoder_attention_mask,
311
+ # return_dict=False,
312
+ # )[0]
313
+ hidden_states = attn(
314
+ hidden_states,
315
+ encoder_hidden_states=encoder_hidden_states,
316
+ cross_attention_kwargs=cross_attention_kwargs,
317
+ )[0]
318
+
319
+ if self.upsamplers is not None:
320
+ for upsampler in self.upsamplers:
321
+ hidden_states = upsampler(hidden_states, upsample_size)
322
+
323
+ return hidden_states
324
+
325
+ return forward
326
+
327
+ for i, upsample_block in enumerate(model.unet.up_blocks):
328
+ if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
329
+ upsample_block.forward = up_forward(upsample_block)
330
+ setattr(upsample_block, 'b1', b1)
331
+ setattr(upsample_block, 'b2', b2)
332
+ setattr(upsample_block, 's1', s1)
333
+ setattr(upsample_block, 's2', s2)
334
+ setattr(upsample_block, 'source_mask', source_mask)
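A minimal sketch of enabling the FreeU-style scaling through the two register helpers above; the model id is a placeholder and the b1/b2/s1/s2 values are the defaults used in this file:

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

    # patch both the plain and the cross-attention up blocks of pipe.unet
    register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)
    register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2)

    image = pipe("a photo of an astronaut riding a horse").images[0]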
utils/masactrl_utils.py ADDED
@@ -0,0 +1,212 @@
1
+ import os
2
+ import cv2
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+ from typing import Optional, Union, Tuple, List, Callable, Dict
9
+
10
+ from torchvision.utils import save_image
11
+ from einops import rearrange, repeat
12
+
13
+
14
+ class AttentionBase:
15
+ def __init__(self):
16
+ self.cur_step = 0
17
+ self.num_att_layers = -1
18
+ self.cur_att_layer = 0
19
+
20
+ def after_step(self):
21
+ pass
22
+
23
+ def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
24
+ out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
25
+ self.cur_att_layer += 1
26
+ if self.cur_att_layer == self.num_att_layers:
27
+ self.cur_att_layer = 0
28
+ self.cur_step += 1
29
+ # after step
30
+ self.after_step()
31
+ return out
32
+
33
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
34
+ out = torch.einsum('b i j, b j d -> b i d', attn, v)
35
+ out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
36
+ return out
37
+
38
+ def reset(self):
39
+ self.cur_step = 0
40
+ self.cur_att_layer = 0
41
+
42
+
43
+ class AttentionStore(AttentionBase):
44
+ def __init__(self, res=[32], min_step=0, max_step=1000):
45
+ super().__init__()
46
+ self.res = res
47
+ self.min_step = min_step
48
+ self.max_step = max_step
49
+ self.valid_steps = 0
50
+
51
+ self.self_attns = [] # store all the attns
52
+ self.cross_attns = []
53
+
54
+ self.self_attns_step = [] # store the attns in each step
55
+ self.cross_attns_step = []
56
+
57
+ def after_step(self):
58
+ if self.cur_step > self.min_step and self.cur_step < self.max_step:
59
+ self.valid_steps += 1
60
+ if len(self.self_attns) == 0:
61
+ self.self_attns = self.self_attns_step
62
+ self.cross_attns = self.cross_attns_step
63
+ else:
64
+ for i in range(len(self.self_attns)):
65
+ self.self_attns[i] += self.self_attns_step[i]
66
+ self.cross_attns[i] += self.cross_attns_step[i]
67
+ self.self_attns_step.clear()
68
+ self.cross_attns_step.clear()
69
+
70
+ def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
71
+ if attn.shape[1] <= 64 ** 2: # avoid OOM
72
+ if is_cross:
73
+ self.cross_attns_step.append(attn)
74
+ else:
75
+ self.self_attns_step.append(attn)
76
+ return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
77
+
78
+
79
+ def regiter_attention_editor_diffusers(model, editor: AttentionBase):
80
+ """
81
+ Register a attention editor to Diffuser Pipeline, refer from [Prompt-to-Prompt]
82
+ """
83
+ def ca_forward(self, place_in_unet):
84
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
85
+ """
86
+ The attention is similar to the original implementation of LDM CrossAttention class
87
+ except adding some modifications on the attention
88
+ """
89
+ if encoder_hidden_states is not None:
90
+ context = encoder_hidden_states
91
+ if attention_mask is not None:
92
+ mask = attention_mask
93
+
94
+ to_out = self.to_out
95
+ if isinstance(to_out, nn.modules.container.ModuleList):
96
+ to_out = self.to_out[0]
97
+ else:
98
+ to_out = self.to_out
99
+
100
+ h = self.heads
101
+ q = self.to_q(x)
102
+ is_cross = context is not None
103
+ context = context if is_cross else x
104
+ k = self.to_k(context)
105
+ v = self.to_v(context)
106
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
107
+
108
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
109
+
110
+ if mask is not None:
111
+ mask = rearrange(mask, 'b ... -> b (...)')
112
+ max_neg_value = -torch.finfo(sim.dtype).max
113
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
114
+ mask = mask[:, None, :].repeat(h, 1, 1)
115
+ sim.masked_fill_(~mask, max_neg_value)
116
+
117
+ attn = sim.softmax(dim=-1)
118
+ # the only difference
119
+ out = editor(
120
+ q, k, v, sim, attn, is_cross, place_in_unet,
121
+ self.heads, scale=self.scale)
122
+
123
+ return to_out(out)
124
+
125
+ return forward
126
+
127
+ def register_editor(net, count, place_in_unet):
128
+ for name, subnet in net.named_children():
129
+ if net.__class__.__name__ == 'Attention': # spatial Transformer layer
130
+ net.forward = ca_forward(net, place_in_unet)
131
+ return count + 1
132
+ elif hasattr(net, 'children'):
133
+ count = register_editor(subnet, count, place_in_unet)
134
+ return count
135
+
136
+ cross_att_count = 0
137
+ for net_name, net in model.unet.named_children():
138
+ if "down" in net_name:
139
+ cross_att_count += register_editor(net, 0, "down")
140
+ elif "mid" in net_name:
141
+ cross_att_count += register_editor(net, 0, "mid")
142
+ elif "up" in net_name:
143
+ cross_att_count += register_editor(net, 0, "up")
144
+ editor.num_att_layers = cross_att_count
145
+
146
+
147
+ def regiter_attention_editor_ldm(model, editor: AttentionBase):
148
+ """
149
+ Register a attention editor to Stable Diffusion model, refer from [Prompt-to-Prompt]
150
+ """
151
+ def ca_forward(self, place_in_unet):
152
+ def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
153
+ """
154
+ The attention is similar to the original implementation of LDM CrossAttention class
155
+ except adding some modifications on the attention
156
+ """
157
+ if encoder_hidden_states is not None:
158
+ context = encoder_hidden_states
159
+ if attention_mask is not None:
160
+ mask = attention_mask
161
+
162
+ to_out = self.to_out
163
+ if isinstance(to_out, nn.modules.container.ModuleList):
164
+ to_out = self.to_out[0]
165
+ else:
166
+ to_out = self.to_out
167
+
168
+ h = self.heads
169
+ q = self.to_q(x)
170
+ is_cross = context is not None
171
+ context = context if is_cross else x
172
+ k = self.to_k(context)
173
+ v = self.to_v(context)
174
+ q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
175
+
176
+ sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale
177
+
178
+ if mask is not None:
179
+ mask = rearrange(mask, 'b ... -> b (...)')
180
+ max_neg_value = -torch.finfo(sim.dtype).max
181
+ mask = repeat(mask, 'b j -> (b h) () j', h=h)
182
+ mask = mask[:, None, :].repeat(h, 1, 1)
183
+ sim.masked_fill_(~mask, max_neg_value)
184
+
185
+ attn = sim.softmax(dim=-1)
186
+ # the only difference
187
+ out = editor(
188
+ q, k, v, sim, attn, is_cross, place_in_unet,
189
+ self.heads, scale=self.scale)
190
+
191
+ return to_out(out)
192
+
193
+ return forward
194
+
195
+ def register_editor(net, count, place_in_unet):
196
+ for name, subnet in net.named_children():
197
+ if net.__class__.__name__ == 'CrossAttention': # spatial Transformer layer
198
+ net.forward = ca_forward(net, place_in_unet)
199
+ return count + 1
200
+ elif hasattr(net, 'children'):
201
+ count = register_editor(subnet, count, place_in_unet)
202
+ return count
203
+
204
+ cross_att_count = 0
205
+ for net_name, net in model.model.diffusion_model.named_children():
206
+ if "input" in net_name:
207
+ cross_att_count += register_editor(net, 0, "input")
208
+ elif "middle" in net_name:
209
+ cross_att_count += register_editor(net, 0, "middle")
210
+ elif "output" in net_name:
211
+ cross_att_count += register_editor(net, 0, "output")
212
+ editor.num_att_layers = cross_att_count
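For reference, a minimal sketch of how the utilities in this file are wired together — only the AttentionStore and regiter_attention_editor_diffusers names come from the code above (the helper keeps its original "regiter_" spelling); the checkpoint and prompt are placeholders:

# Sketch only: hook an AttentionStore into the attention layers of a diffusers
# pipeline, run one generation, then read back the accumulated maps.
import torch
from diffusers import StableDiffusionPipeline
from utils.masactrl_utils import AttentionStore, regiter_attention_editor_diffusers

model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")

editor = AttentionStore(res=[32])
regiter_attention_editor_diffusers(model, editor)  # patches model.unet attention forwards in place

with torch.no_grad():
    model("a portrait photo of a person", num_inference_steps=50)

# self_attns / cross_attns hold per-layer maps summed over the stored steps.
print(len(editor.self_attns), len(editor.cross_attns), editor.valid_steps)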
utils/style_attn_control.py ADDED
@@ -0,0 +1,275 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import os
+ import sys
+
+ import numpy as np
+
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import torch.fft as fft
+
+ from einops import rearrange, repeat
+ from torchvision.utils import save_image
+
+ from diffusers.utils import deprecate, logging
+ from diffusers.utils.import_utils import is_xformers_available
+ # from masactrl.masactrl import MutualSelfAttentionControl
+
+ from .masactrl_utils import AttentionBase
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ if is_xformers_available():
+     import xformers
+     import xformers.ops
+ else:
+     xformers = None
+
+
+ # AttentionBase is provided by .masactrl_utils (imported above).
+
+
+ class MaskPromptedStyleAttentionControl(AttentionBase):
+     def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None, total_steps=50, style_guidance=0.1,
+                  only_masked_region=False, guidance=0.0,
+                  style_mask=None, source_mask=None, de_bug=False):
+         """
+         Mask-prompted style attention control (MaskPromptedSAC).
+         Args:
+             start_step: the denoising step at which mutual self-attention control starts
+             start_layer: the U-Net attention layer at which mutual self-attention control starts
+             style_attn_step: the step from which the target branch uses its own query (qt) instead of the content query (qc)
+             layer_idx: list of the layers to apply mutual self-attention control to
+             step_idx: list of the steps to apply mutual self-attention control to
+             total_steps: the total number of denoising steps
+             style_guidance: scale applied to the (stylized - content) attention output
+             only_masked_region: if True, restrict style injection to the masked region only
+             guidance: extra guidance scale, kept as an attribute
+             style_mask: mask of the stylized region in the style image
+             source_mask: mask of the region to stylize in the source image
+             de_bug: drop into pdb at selected points for debugging
+         """
+         super().__init__()
+         self.total_steps = total_steps
+         self.total_layers = 16
+         self.start_step = start_step
+         self.start_layer = start_layer
+         self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
+         self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
+         print("using MaskPromptedStyleAttentionControl")
+         print("MaskedSAC at denoising steps: ", self.step_idx)
+         print("MaskedSAC at U-Net layers: ", self.layer_idx)
+
+         self.de_bug = de_bug
+         self.style_guidance = style_guidance
+         self.only_masked_region = only_masked_region
+         self.style_attn_step = style_attn_step
+         self.self_attns = []
+         self.cross_attns = []
+         self.guidance = guidance
+         self.style_mask = style_mask
+         self.source_mask = source_mask
+
+     def after_step(self):
+         self.self_attns = []
+         self.cross_attns = []
+
+     def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
+         B = q.shape[0] // num_heads
+         H = W = int(np.sqrt(q.shape[1]))
+         q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
+         k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
+         v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
+
+         sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
+
+         if q_mask is not None:
+             sim = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
+
+         if k_mask is not None:
+             sim = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
+
+         attn = sim.softmax(-1) if attn is None else attn
+
+         if len(attn) == 2 * len(v):
+             v = torch.cat([v] * 2)
+         out = torch.einsum("h i j, h j d -> h i d", attn, v)
+         out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+         return out
+
+     def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
+         B = q.shape[0] // num_heads
+         H = W = int(np.sqrt(q.shape[1]))
+         q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
+         k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
+         v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
+         sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
+         # build separate foreground / background similarity maps from the masks
+         if q_mask is not None:
+             sim_fg = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
+             sim_bg = sim.masked_fill(q_mask.unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
+         if k_mask is not None:
+             sim_fg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
+             sim_bg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
+         sim = torch.cat([sim_fg, sim_bg])
+         attn = sim.softmax(-1)
+
+         if len(attn) == 2 * len(v):
+             v = torch.cat([v] * 2)
+         out = torch.einsum("h i j, h j d -> h i d", attn, v)
+         out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
+         return out
+
+     def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
+         """
+         Attention forward function; q/k/v hold the style, content and target branches
+         concatenated along the batch dimension.
+         """
+         if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
+             return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
+
+         B = q.shape[0] // num_heads // 2
+         H = W = int(np.sqrt(q.shape[1]))
+
+         if self.style_mask is not None and self.source_mask is not None:
+             # mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx)  # (4, H, W)
+             height, width = self.style_mask.shape[-2:]
+             mask_style = self.style_mask    # (H, W)
+             mask_source = self.source_mask  # (H, W)
+             scale = int(np.sqrt(height * width / q.shape[1]))
+             # res = int(np.sqrt(q.shape[1]))
+             spatial_mask_source = F.interpolate(mask_source, (height // scale, width // scale)).reshape(-1, 1)
+             spatial_mask_style = F.interpolate(mask_style, (height // scale, width // scale)).reshape(-1, 1)
+         else:
+             spatial_mask_source = None
+             spatial_mask_style = None
+
+         if spatial_mask_style is None or spatial_mask_source is None:
+             out_s, out_c, out_t = self.style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
+         else:
+             if self.only_masked_region:
+                 out_s, out_c, out_t = self.mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
+             else:
+                 out_s, out_c, out_t = self.separate_mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
+
+         out = torch.cat([out_s, out_c, out_t], dim=0)
+         return out
+
+     def style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
+         if self.de_bug:
+             import pdb; pdb.set_trace()
+
+         qs, qc, qt = q.chunk(3)  # style, content and target queries
+
+         out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+         out_c = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+
+         if self.cur_step < self.style_attn_step:
+             out_t = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+         else:
+             out_t = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+             if self.style_guidance >= 0:
+                 out_t = out_c + (out_t - out_c) * self.style_guidance
+         return out_s, out_c, out_t
+
+     def mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
+         qs, qc, qt = q.chunk(3)
+
+         out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+         out_c = self.attn_batch(qc, k[num_heads: 2*num_heads], v[num_heads: 2*num_heads], sim[num_heads: 2*num_heads], attn[num_heads: 2*num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+         out_c_new = self.attn_batch(qc, k[num_heads: 2*num_heads], v[num_heads: 2*num_heads], sim[num_heads: 2*num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+
+         if self.de_bug:
+             import pdb; pdb.set_trace()
+
+         if self.cur_step < self.style_attn_step:
+             out_t = out_c  # self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+         else:
+             out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+             out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+             if self.style_guidance >= 0:
+                 out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
+
+             # keep the stylized result inside the source mask, content everywhere else
+             out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)
+
+         if self.de_bug:
+             import pdb; pdb.set_trace()
+
+         # print(torch.sum(out_t * (1 - spatial_mask_source) - out_c * (1 - spatial_mask_source)))
+         return out_s, out_c, out_t
+
+     def separate_mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
+         if self.de_bug:
+             import pdb; pdb.set_trace()
+         # To prevent query confusion, render foreground and background separately according to the masks.
+         qs, qc, qt = q.chunk(3)
+         out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
+         if self.cur_step < self.style_attn_step:
+             out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+             out_c_fg, out_c_bg = out_c.chunk(2)
+             out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
+         else:
+             out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+             out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
+             out_t_fg, out_t_bg = out_t.chunk(2)
+             out_c_fg, out_c_bg = out_c.chunk(2)
+             if self.style_guidance >= 0:
+                 out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
+                 out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
+             out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)
+
+         return out_s, out_t, out_t
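Finally, a minimal sketch of how MaskPromptedStyleAttentionControl is attached to a pipeline — the masks, step/layer settings and the generic StableDiffusionPipeline are placeholders; in this Space the controller is expected to be driven by the custom pipeline in pdiff/pdiff_pipeline.py, which would feed the UNet the style, content and target branches concatenated along the batch dimension:

# Sketch only: the mask tensors, sizes and registration target are assumptions.
import torch
from diffusers import StableDiffusionPipeline
from utils.masactrl_utils import regiter_attention_editor_diffusers
from utils.style_attn_control import MaskPromptedStyleAttentionControl

model = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")

# Binary (1, 1, H, W) float masks marking the region to stylize in the style and
# source images; all-ones here as a placeholder.
style_mask = torch.ones(1, 1, 512, 512)
source_mask = torch.ones(1, 1, 512, 512)

controller = MaskPromptedStyleAttentionControl(
    start_step=4, start_layer=10, style_attn_step=35,
    total_steps=50, style_guidance=0.1,
    only_masked_region=False,
    style_mask=style_mask, source_mask=source_mask,
)
# Registration only needs model.unet, so any pipeline exposing a .unet works here.
regiter_attention_editor_diffusers(model, controller)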