initial add
- .gitignore +16 -0
- app.py +529 -0
- app.sh +7 -0
- pdiff/pdiff_pipeline.py +275 -0
- requirements.txt +6 -0
- style.css +3 -0
- utils/__init__.py +0 -0
- utils/__pycache__/__init__.cpython-310.pyc +0 -0
- utils/__pycache__/__init__.cpython-38.pyc +0 -0
- utils/__pycache__/__init__.cpython-39.pyc +0 -0
- utils/__pycache__/convert_from_ckpt.cpython-310.pyc +0 -0
- utils/__pycache__/convert_from_ckpt.cpython-38.pyc +0 -0
- utils/__pycache__/convert_from_ckpt.cpython-39.pyc +0 -0
- utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc +0 -0
- utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc +0 -0
- utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-39.pyc +0 -0
- utils/__pycache__/diffuser_utils.cpython-310.pyc +0 -0
- utils/__pycache__/diffuser_utils.cpython-38.pyc +0 -0
- utils/__pycache__/diffuser_utils.cpython-39.pyc +0 -0
- utils/__pycache__/free_lunch_utils.cpython-310.pyc +0 -0
- utils/__pycache__/free_lunch_utils.cpython-38.pyc +0 -0
- utils/__pycache__/free_lunch_utils.cpython-39.pyc +0 -0
- utils/__pycache__/masactrl_utils.cpython-310.pyc +0 -0
- utils/__pycache__/masactrl_utils.cpython-38.pyc +0 -0
- utils/__pycache__/masactrl_utils.cpython-39.pyc +0 -0
- utils/__pycache__/style_attn_control.cpython-310.pyc +0 -0
- utils/__pycache__/style_attn_control.cpython-38.pyc +0 -0
- utils/__pycache__/style_attn_control.cpython-39.pyc +0 -0
- utils/convert_from_ckpt.py +959 -0
- utils/convert_lora_safetensor_to_diffusers.py +154 -0
- utils/diffuser_utils.py +275 -0
- utils/free_lunch_utils.py +334 -0
- utils/masactrl_utils.py +212 -0
- utils/style_attn_control.py +275 -0
.gitignore
ADDED
@@ -0,0 +1,16 @@
./data
./results
./results_ablation
./workdir
./row_results
./new_res
./cop
examper
results
data
results_ablation
row_results
new_res
cop
./samples
samples
app.py
ADDED
@@ -0,0 +1,529 @@
import os
import json
import re
from turtle import width
import torch
import random
import numpy as np
import gradio as gr
from glob import glob
from omegaconf import OmegaConf
from datetime import datetime
from safetensors import safe_open

from diffusers import AutoencoderKL,UNet2DConditionModel,StableDiffusionPipeline
from diffusers import DDIMScheduler, EulerDiscreteScheduler, PNDMScheduler
from diffusers.utils.import_utils import is_xformers_available
from transformers import CLIPTextModel, CLIPTokenizer
from utils.convert_from_ckpt import convert_ldm_unet_checkpoint, convert_ldm_clip_checkpoint, convert_ldm_vae_checkpoint
from utils.convert_lora_safetensor_to_diffusers import convert_lora

import torch.nn.functional as F
from PIL import Image

from utils.diffuser_utils import MasaCtrlPipeline
from utils.masactrl_utils import (AttentionBase,
                                  regiter_attention_editor_diffusers)
from utils.free_lunch_utils import register_upblock2d,register_crossattn_upblock2d,register_free_upblock2d, register_free_crossattn_upblock2d

from utils.style_attn_control import MaskPromptedStyleAttentionControl
from torchvision.utils import save_image
from diffusers.models.attention_processor import AttnProcessor2_0




css = """
.toolbutton {
    margin-buttom: 0em 0em 0em 0em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    height: 2.5em;
}
"""

class GlobalText:
    def __init__(self):

        # config dirs
        self.basedir = os.getcwd()
        self.stable_diffusion_dir = os.path.join(self.basedir, "models", "StableDiffusion")
        self.personalized_model_dir = '/home/jin.liu/liujin/webui/stable-diffusion-webui/models/Stable-diffusion'
        self.lora_model_dir = '/home/jin.liu/liujin/webui/stable-diffusion-webui/models/Lora'
        self.savedir = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
        self.savedir_sample = os.path.join(self.savedir, "sample")

        self.savedir_mask = os.path.join(self.savedir, "mask")

        self.stable_diffusion_list = ["/home/jin.liu/liujin/cache/hub/models--runwayml--stable-diffusion-v1-5/snapshots/1d0c4ebf6ff58a5caecab40fa1406526bca4b5b9",
                                      "runwayml/stable-diffusion-v1-5",
                                      "stabilityai/stable-diffusion-2-1"]
        self.personalized_model_list = []
        self.lora_model_list = []

        # config models
        self.tokenizer = None
        self.text_encoder = None
        self.vae = None
        self.unet = None
        self.pipeline = None
        self.lora_loaded = None
        self.personal_model_loaded = None
        self.lora_model_state_dict = {}
        self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        self.refresh_stable_diffusion()
        self.refresh_personalized_model()

        self.reset_start_code()
    def load_base_pipeline(self, model_path):
        print(f'loading {model_path} model')
        scheduler = DDIMScheduler.from_pretrained(model_path,subfolder="scheduler")
        self.pipeline = MasaCtrlPipeline.from_pretrained(model_path,
                                                         scheduler=scheduler).to(self.device)

    def refresh_stable_diffusion(self):

        self.load_base_pipeline(self.stable_diffusion_list[0])
        self.lora_loaded = None
        self.personal_model_loaded = None
        return self.stable_diffusion_list[0]

    def refresh_personalized_model(self):
        personalized_model_list = glob(os.path.join(self.personalized_model_dir, "**/*.safetensors"), recursive=True)
        self.personalized_model_list = {os.path.basename(file): file for file in personalized_model_list}

        lora_model_list = glob(os.path.join(self.lora_model_dir, "**/*.safetensors"), recursive=True)
        self.lora_model_list = {os.path.basename(file): file for file in lora_model_list}

    def update_stable_diffusion(self, stable_diffusion_dropdown):

        self.load_base_pipeline(stable_diffusion_dropdown)
        self.lora_loaded = None
        self.personal_model_loaded = None
        return gr.Dropdown.update()

    def update_base_model(self, base_model_dropdown):
        if self.pipeline is None:
            gr.Info(f"Please select a pretrained model path.")
            return None
        else:
            base_model = self.personalized_model_list[base_model_dropdown]
            mid_model = StableDiffusionPipeline.from_single_file(base_model)
            self.pipeline.vae = mid_model.vae
            self.pipeline.unet = mid_model.unet
            self.pipeline.text_encoder = mid_model.text_encoder
            self.pipeline.to(self.device)
            self.personal_model_loaded = base_model_dropdown.split('.')[0]
            print(f'load {base_model_dropdown} model success!')

            return gr.Dropdown()


    def update_lora_model(self, lora_model_dropdown,lora_alpha_slider):

        if self.pipeline is None:
            gr.Info(f"Please select a pretrained model path.")
            return None
        else:
            if lora_model_dropdown == "none":
                self.pipeline.unfuse_lora()
                self.pipeline.unload_lora_weights()
                self.lora_loaded = None
                # self.personal_model_loaded = None
                print("Restore lora.")
            else:

                lora_model_path = self.lora_model_list[lora_model_dropdown]#os.path.join(self.lora_model_dir, lora_model_dropdown)
                # self.lora_model_state_dict = {}
                # if lora_model_dropdown == "none": pass
                # else:
                #     with safe_open(lora_model_dropdown, framework="pt", device="cpu") as f:
                #         for key in f.keys():
                #             self.lora_model_state_dict[key] = f.get_tensor(key)
                # convert_lora(self.pipeline, self.lora_model_state_dict, alpha=lora_alpha_slider)
                self.pipeline.unfuse_lora()
                self.pipeline.unload_lora_weights()
                self.pipeline.load_lora_weights(lora_model_path)
                self.pipeline.fuse_lora(lora_alpha_slider)
                self.lora_loaded = lora_model_dropdown.split('.')[0]
                print(f'load {lora_model_dropdown} model success!')
            return gr.Dropdown()

    def generate(self, source, style, source_mask, style_mask,
                 start_step, start_layer, Style_attn_step,
                 Method, Style_Guidance, ddim_steps, scale, seed, de_bug,
                 target_prompt, negative_prompt_textbox,
                 inter_latents,
                 freeu, b1, b2, s1, s2,
                 width_slider,height_slider,
                 ):
        os.makedirs(self.savedir, exist_ok=True)
        os.makedirs(self.savedir_sample, exist_ok=True)
        os.makedirs(self.savedir_mask, exist_ok=True)
        model = self.pipeline

        if seed != -1 and seed != "": torch.manual_seed(int(seed))
        else: torch.seed()
        seed = torch.initial_seed()
        sample_count = len(os.listdir(self.savedir_sample))
        os.makedirs(os.path.join(self.savedir_mask, f"results_{sample_count}"), exist_ok=True)

        # ref_prompt = [source_prompt, target_prompt]
        # prompts = ref_prompt+['']
        ref_prompt = [target_prompt, target_prompt]
        prompts = ref_prompt+[target_prompt]
        source_image,style_image,source_mask,style_mask = load_mask_images(source,style,source_mask,style_mask,self.device,width_slider,height_slider,out_dir=os.path.join(self.savedir_mask, f"results_{sample_count}"))


        # global START_CODE, LATENTS_LIST

        with torch.no_grad():
            #import pdb;pdb.set_trace()

            #prev_source
            if self.start_code is None and self.latents_list is None:
                content_style = torch.cat([style_image, source_image], dim=0)
                editor = AttentionBase()
                regiter_attention_editor_diffusers(model, editor)
                st_code, latents_list = model.invert(content_style,
                                                     ref_prompt,
                                                     guidance_scale=scale,
                                                     num_inference_steps=ddim_steps,
                                                     return_intermediates=True)
                start_code = torch.cat([st_code, st_code[1:]], dim=0)
                self.start_code = start_code
                self.latents_list = latents_list
            else:
                start_code = self.start_code
                latents_list = self.latents_list
                print('------------------------------------------ Use previous latents ------------------------------------------ ')

            #["Without mask", "Only masked region", "Seperate Background Foreground"]

            if Method == "Without mask":
                style_mask = None
                source_mask = None
                only_masked_region = False
            elif Method == "Only masked region":
                assert style_mask is not None and source_mask is not None
                only_masked_region = True
            else:
                assert style_mask is not None and source_mask is not None
                only_masked_region = False

            controller = MaskPromptedStyleAttentionControl(start_step, start_layer,
                                                           style_attn_step=Style_attn_step,
                                                           style_guidance=Style_Guidance,
                                                           style_mask=style_mask,
                                                           source_mask=source_mask,
                                                           only_masked_region=only_masked_region,
                                                           guidance=scale,
                                                           de_bug=de_bug,
                                                           )
            if freeu:
                print(f'++++++++++++++++++ Run with FreeU {b1}_{b2}_{s1}_{s2} ++++++++++++++++')
                if Method != "Without mask":
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=source_mask)
                else:
                    register_free_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=None)
                    register_free_crossattn_upblock2d(model, b1=b1, b2=b2, s1=s1, s2=s1,source_mask=None)

            else:
                print(f'++++++++++++++++++ Run without FreeU ++++++++++++++++')
                register_upblock2d(model)
                register_crossattn_upblock2d(model)
                regiter_attention_editor_diffusers(model, controller)

            regiter_attention_editor_diffusers(model, controller)

            # inference the synthesized image
            generate_image= model(prompts,
                                  width=width_slider,
                                  height=height_slider,
                                  latents=start_code,
                                  guidance_scale=scale,
                                  num_inference_steps=ddim_steps,
                                  ref_intermediate_latents=latents_list if inter_latents else None,
                                  neg_prompt=negative_prompt_textbox,
                                  return_intermediates=False,)

        # os.makedirs(os.path.join(output_dir, f"results_{sample_count}"))
        save_file_name = f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}.jpg"
        if self.lora_loaded != None:
            save_file_name = f"lora_{self.lora_loaded}_" + save_file_name
        if self.personal_model_loaded != None:
            save_file_name = f"personal_{self.personal_model_loaded}_" + save_file_name
        #f"results_{sample_count}_step{start_step}_layer{start_layer}SG{Style_Guidance}_style_attn_step{Style_attn_step}_lora_{self.lora_loaded}.jpg"
        save_file_path = os.path.join(self.savedir_sample, save_file_name)
        #save_file_name = os.path.join(output_dir, f"results_style_{style_name}", f"{content_name}.jpg")

        save_image(torch.cat([source_image/2 + 0.5, style_image/2 + 0.5, generate_image[2:]], dim=0), save_file_path, nrow=3, padding=0)



        # global OUTPUT_RESULT
        # OUTPUT_RESULT = save_file_name

        generate_image = generate_image.cpu().permute(0, 2, 3, 1).numpy()
        #save_gif(latents_list, os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif'))
        # import pdb;pdb.set_trace()
        #gif_dir = os.path.join(output_dir, f"results_{sample_count}",'output_latents_list.gif')

        return [
            generate_image[0],
            generate_image[1],
            generate_image[2],
        ]

    def reset_start_code(self,):
        self.start_code = None
        self.latents_list = None

global_text = GlobalText()


def load_mask_images(source,style,source_mask,style_mask,device,width,height,out_dir=None):
    # invert the image into noise map
    if isinstance(source['image'], np.ndarray):
        source_image = torch.from_numpy(source['image']).to(device) / 127.5 - 1.
    else:
        source_image = torch.from_numpy(np.array(source['image'])).to(device) / 127.5 - 1.
    source_image = source_image.unsqueeze(0).permute(0, 3, 1, 2)

    source_image = F.interpolate(source_image, (height,width ))

    if out_dir is not None and source_mask is None:

        source['mask'].save(os.path.join(out_dir,'source_mask.jpg'))
    else:
        Image.fromarray(source_mask).save(os.path.join(out_dir,'source_mask.jpg'))
    if out_dir is not None and style_mask is None:

        style['mask'].save(os.path.join(out_dir,'style_mask.jpg'))
    else:
        Image.fromarray(style_mask).save(os.path.join(out_dir,'style_mask.jpg'))
    # save source['mask']
    # import pdb;pdb.set_trace()
    source_mask = torch.from_numpy(np.array(source['mask']) if source_mask is None else source_mask).to(device) / 255.
    source_mask = source_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
    source_mask = F.interpolate(source_mask, (height//8,width//8))

    if isinstance(source['image'], np.ndarray):
        style_image = torch.from_numpy(style['image']).to(device) / 127.5 - 1.
    else:
        style_image = torch.from_numpy(np.array(style['image'])).to(device) / 127.5 - 1.
    style_image = style_image.unsqueeze(0).permute(0, 3, 1, 2)
    style_image = F.interpolate(style_image, (height,width))

    style_mask = torch.from_numpy(np.array(style['mask']) if style_mask is None else style_mask ).to(device) / 255.
    style_mask = style_mask.unsqueeze(0).permute(0, 3, 1, 2)[:,:1]
    style_mask = F.interpolate(style_mask, (height//8,width//8))


    return source_image,style_image,source_mask,style_mask


def ui():
    with gr.Blocks(css=css) as demo:
        gr.Markdown(
            """
            # [Portrait Diffusion: Training-free Face Stylization with Chain-of-Painting](https://arxiv.org/abs/00000)
            Jin Liu, Huaibo Huang, Chao Jin, Ran He* (*Corresponding Author)<br>
            [Arxiv Report](https://arxiv.org/abs/0000) | [Project Page](https://www.github.io/) | [Github](https://github.com/)
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Select a pretrained model.
                """
            )
            with gr.Row():
                stable_diffusion_dropdown = gr.Dropdown(
                    label="Pretrained Model Path",
                    choices=global_text.stable_diffusion_list,
                    interactive=True,
                    allow_custom_value=True
                )
                stable_diffusion_dropdown.change(fn=global_text.update_stable_diffusion, inputs=[stable_diffusion_dropdown], outputs=[stable_diffusion_dropdown])

                stable_diffusion_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_stable_diffusion():
                    global_text.refresh_stable_diffusion()

                stable_diffusion_refresh_button.click(fn=update_stable_diffusion, inputs=[], outputs=[])

                base_model_dropdown = gr.Dropdown(
                    label="Select a ckpt model (optional)",
                    choices=sorted(list(global_text.personalized_model_list.keys())),
                    interactive=True,
                    allow_custom_value=True,
                )
                base_model_dropdown.change(fn=global_text.update_base_model, inputs=[base_model_dropdown], outputs=[base_model_dropdown])

                lora_model_dropdown = gr.Dropdown(
                    label="Select a LoRA model (optional)",
                    choices=["none"] + sorted(list(global_text.lora_model_list.keys())),
                    value="none",
                    interactive=True,
                    allow_custom_value=True,
                )
                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.8, minimum=0, maximum=2, interactive=True)
                lora_model_dropdown.change(fn=global_text.update_lora_model, inputs=[lora_model_dropdown,lora_alpha_slider], outputs=[lora_model_dropdown])



                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")

                def update_personalized_model():
                    global_text.refresh_personalized_model()
                    return [
                        gr.Dropdown(choices=sorted(list(global_text.personalized_model_list.keys()))),
                        gr.Dropdown(choices=["none"] + sorted(list(global_text.lora_model_list.keys())))
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])


        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for PortraitDiff.
                """
            )
            with gr.Tab("Configs"):

                with gr.Row():
                    source_image = gr.Image(label="Source Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                    style_image = gr.Image(label="Style Image", elem_id="img2maskimg", source="upload", interactive=True, type="pil", tool="sketch", image_mode="RGB", height=512)
                with gr.Row():
                    prompt_textbox = gr.Textbox(label="Prompt", value='head', lines=1)
                    negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=1)
                    # output_dir = gr.Textbox(label="output_dir", value='./results/')

                with gr.Row().style(equal_height=False):
                    with gr.Column():
                        width_slider = gr.Slider(label="Width", value=512, minimum=256, maximum=1024, step=64)
                        height_slider = gr.Slider(label="Height", value=512, minimum=256, maximum=1024, step=64)
                        Method = gr.Dropdown(
                            ["Without mask", "Only masked region", "Seperate Background Foreground"],
                            value="Without mask",
                            label="Mask", info="Select how to use masks")
                        with gr.Tab('Base Configs'):
                            with gr.Row():
                                # sampler_dropdown = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                                ddim_steps = gr.Slider(label="DDIM Steps", value=50, minimum=10, maximum=100, step=1)

                                Style_attn_step = gr.Slider(label="Step of Style Attention Control",
                                                            minimum=0,
                                                            maximum=50,
                                                            value=35,
                                                            step=1)
                                start_step = gr.Slider(label="Step of Attention Control",
                                                       minimum=0,
                                                       maximum=150,
                                                       value=0,
                                                       step=1)
                                start_layer = gr.Slider(label="Layer of Style Attention Control",
                                                        minimum=0,
                                                        maximum=16,
                                                        value=10,
                                                        step=1)
                                Style_Guidance = gr.Slider(label="Style Guidance Scale",
                                                           minimum=0,
                                                           maximum=4,
                                                           value=1.2,
                                                           step=0.05)
                                cfg_scale_slider = gr.Slider(label="CFG Scale", value=0, minimum=0, maximum=20)


                        with gr.Tab('FreeU'):
                            with gr.Row():
                                freeu = gr.Checkbox(label="Free Upblock", value=False)
                                de_bug = gr.Checkbox(value=False,label='DeBug')
                                inter_latents = gr.Checkbox(value=True,label='Use intermediate latents')
                            with gr.Row():
                                b1 = gr.Slider(label='b1:',
                                               minimum=-1,
                                               maximum=2,
                                               step=0.01,
                                               value=1.3)
                                b2 = gr.Slider(label='b2:',
                                               minimum=-1,
                                               maximum=2,
                                               step=0.01,
                                               value=1.5)
                            with gr.Row():
                                s1 = gr.Slider(label='s1: ',
                                               minimum=0,
                                               maximum=2,
                                               step=0.1,
                                               value=1.0)
                                s2 = gr.Slider(label='s2:',
                                               minimum=0,
                                               maximum=2,
                                               step=0.1,
                                               value=1.0)
                        with gr.Row():
                            seed_textbox = gr.Textbox(label="Seed", value=-1)
                            seed_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                            seed_button.click(fn=lambda: random.randint(1, 1e8), inputs=[], outputs=[seed_textbox])

                    with gr.Column():
                        generate_button = gr.Button(value="Generate", variant='primary')

                        generate_image = gr.Image(label="Image with PortraitDiff", interactive=False, type='numpy', height=512,)

                        with gr.Row():
                            recons_content = gr.Image(label="reconstructed content", type="pil", image_mode="RGB", height=256)
                            recons_style = gr.Image(label="reconstructed style", type="pil", image_mode="RGB", height=256)

            with gr.Tab("SAM"):
                with gr.Column():
                    add_or_remove = gr.Radio(["Add Mask", "Remove Area"], value="Add Mask", label="Point_label (foreground/background)")
                    with gr.Row():
                        sam_source_btn = gr.Button(value="SAM Source")
                        send_source_btn = gr.Button(value="Send Source")

                        sam_style_btn = gr.Button(value="SAM Style")
                        send_style_btn = gr.Button(value="Send Style")
                    with gr.Row():
                        source_image_sam = gr.Image(label="Source Image SAM", elem_id="SourceimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)
                        style_image_sam = gr.Image(label="Style Image SAM", elem_id="StyleimgSAM", source="upload", interactive=True, type="pil", image_mode="RGB", height=512)

                    with gr.Row():
                        source_image_with_points = gr.Image(label="source Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
                        source_mask = gr.Image(label="Source Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)

                        style_image_with_points = gr.Image(label="Style Image with points", elem_id="style_image_with_points", type="pil", image_mode="RGB", height=256)
                        style_mask = gr.Image(label="Style Mask", elem_id="img2maskimg", source="upload", interactive=True, type="numpy", image_mode="RGB", height=256)

        gr.Examples(
            [[os.path.join(os.path.dirname(__file__), "gradio_app/images/content/1.jpg"),
              os.path.join(os.path.dirname(__file__), "gradio_app/images/style/1.jpg")],

            ],
            [source_image, style_image]
        )
        inputs = [
            source_image, style_image, source_mask, style_mask,
            start_step, start_layer, Style_attn_step,
            Method, Style_Guidance,ddim_steps, cfg_scale_slider, seed_textbox, de_bug,
            prompt_textbox, negative_prompt_textbox, inter_latents,
            freeu, b1, b2, s1, s2,
            width_slider,height_slider,
        ]

        generate_button.click(
            fn=global_text.generate,
            inputs=inputs,
            outputs=[recons_style,recons_content,generate_image]
        )
        source_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        style_image.upload(global_text.reset_start_code, inputs=[], outputs=[])
        ddim_steps.change(fn=global_text.reset_start_code, inputs=[], outputs=[])
    return demo

if __name__ == "__main__":
    demo = ui()
    demo.launch(server_name="172.18.32.44")
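For local testing, a minimal sketch of launching the Gradio UI defined above on a generic bind address instead of the hard-coded internal IP; the address and port below are assumptions, not values from this commit.

# Minimal local-launch sketch (assumes it is run from the repo root so "app" imports).
# Note that importing app builds GlobalText(), which loads the base pipeline up front.
from app import ui

demo = ui()
demo.launch(server_name="0.0.0.0", server_port=7860)  # assumed bind address and port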
app.sh
ADDED
@@ -0,0 +1,7 @@
#!/bin/bash

export CUDA_VISIBLE_DEVICES=$1

echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
# export CUDA_VISIBLE_DEVICES=5
python gapp.py
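The launch script pins one GPU and starts the demo; note that it invokes gapp.py even though the entry point added in this commit is app.py. A rough Python equivalent, assuming app.py is the intended entry point:

# Hypothetical Python equivalent of app.sh: pin a GPU, then run the demo.
import os
import subprocess
import sys

gpu_id = sys.argv[1] if len(sys.argv) > 1 else "0"  # assumed default GPU index
os.environ["CUDA_VISIBLE_DEVICES"] = gpu_id
print(f"CUDA_VISIBLE_DEVICES={gpu_id}")
subprocess.run([sys.executable, "app.py"], check=True)  # assumes app.py as entry point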
pdiff/pdiff_pipeline.py
ADDED
@@ -0,0 +1,275 @@
"""
Util functions based on Diffuser framework.
"""


import os
import torch
import cv2
import numpy as np

import torch.nn.functional as F
from tqdm import tqdm
from PIL import Image
from torchvision.utils import save_image
from torchvision.io import read_image

from diffusers import StableDiffusionPipeline

from pytorch_lightning import seed_everything


class MasaCtrlPipeline(StableDiffusionPipeline):

    def next_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling for DDIM Inversion
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta: float=0.0,
        verbose=False,
    ):
        """
        Predict the sample at the next step of the denoising process.
        """
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
        return x_prev, pred_x0

    @torch.no_grad()
    def image2latent(self, image):
        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        if type(image) is Image:
            image = np.array(image)
            image = torch.from_numpy(image).float() / 127.5 - 1
            image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
        # input image density range [-1, 1]
        latents = self.vae.encode(image)['latent_dist'].mean
        latents = latents * 0.18215
        return latents

    @torch.no_grad()
    def latent2image(self, latents, return_type='np'):
        latents = 1 / 0.18215 * latents.detach()
        image = self.vae.decode(latents)['sample']
        if return_type == 'np':
            image = (image / 2 + 0.5).clamp(0, 1)
            image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
            image = (image * 255).astype(np.uint8)
        elif return_type == "pt":
            image = (image / 2 + 0.5).clamp(0, 1)

        return image

    def latent2image_grad(self, latents):
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents)['sample']

        return image  # range [-1, 1]

    @torch.no_grad()
    def __call__(
        self,
        prompt,
        batch_size=1,
        height=512,
        width=512,
        num_inference_steps=50,
        guidance_scale=7.5,
        eta=0.0,
        latents=None,
        unconditioning=None,
        neg_prompt=None,
        ref_intermediate_latents=None,
        return_intermediates=False,
        **kwds):
        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        if isinstance(prompt, list):
            batch_size = len(prompt)
        elif isinstance(prompt, str):
            if batch_size > 1:
                prompt = [prompt] * batch_size

        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )

        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
        print("input text embeddings :", text_embeddings.shape)
        if kwds.get("dir"):
            dir = text_embeddings[-2] - text_embeddings[-1]
            u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
            text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
            print(u.shape)
            print(v.shape)

        # define initial latents
        latents_shape = (batch_size, self.unet.config.in_channels, height//8, width//8)
        if latents is None:
            latents = torch.randn(latents_shape, device=DEVICE)
        else:
            assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."

        # unconditional embedding for classifier free guidance
        if guidance_scale > 1.:
            max_length = text_input.input_ids.shape[-1]
            if neg_prompt:
                uc_text = neg_prompt
            else:
                uc_text = ""
            # uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
            unconditional_input = self.tokenizer(
                [uc_text] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            # unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
            unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)

        print("latents shape: ", latents.shape)
        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        # print("Valid timesteps: ", reversed(self.scheduler.timesteps))
        latents_list = [latents]
        pred_x0_list = [latents]
        for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
            if ref_intermediate_latents is not None:
                # note that the batch_size >= 2
                latents_ref = ref_intermediate_latents[-1 - i]
                _, latents_cur = latents.chunk(2)
                latents = torch.cat([latents_ref, latents_cur])

            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents
            if unconditioning is not None and isinstance(unconditioning, list):
                _, text_embeddings = text_embeddings.chunk(2)
                text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
            # predict the noise
            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
            # compute the previous noise sample x_t -> x_t-1
            latents, pred_x0 = self.step(noise_pred, t, latents)
            latents_list.append(latents)
            pred_x0_list.append(pred_x0)

        image = self.latent2image(latents, return_type="pt")
        if return_intermediates:
            pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
            latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
            return image, pred_x0_list, latents_list
        return image

    @torch.no_grad()
    def invert(
        self,
        image: torch.Tensor,
        prompt,
        num_inference_steps=50,
        guidance_scale=7.5,
        eta=0.0,
        return_intermediates=False,
        **kwds):
        """
        Invert a real image into a noise map with deterministic DDIM inversion.
        """
        DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
        batch_size = image.shape[0]
        if isinstance(prompt, list):
            if batch_size == 1:
                image = image.expand(len(prompt), -1, -1, -1)
        elif isinstance(prompt, str):
            if batch_size > 1:
                prompt = [prompt] * batch_size

        # text embeddings
        text_input = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=77,
            return_tensors="pt"
        )
        text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
        print("input text embeddings :", text_embeddings.shape)
        # define initial latents
        latents = self.image2latent(image)
        start_latents = latents
        # print(latents)
        # exit()
        # unconditional embedding for classifier free guidance
        if guidance_scale > 1.:
            max_length = text_input.input_ids.shape[-1]
            unconditional_input = self.tokenizer(
                [""] * batch_size,
                padding="max_length",
                max_length=77,
                return_tensors="pt"
            )
            unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
            text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)

        print("latents shape: ", latents.shape)
        # iterative sampling
        self.scheduler.set_timesteps(num_inference_steps)
        print("Valid timesteps: ", reversed(self.scheduler.timesteps))
        # print("attributes: ", self.scheduler.__dict__)
        latents_list = [latents]
        pred_x0_list = [latents]
        for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
            if guidance_scale > 1.:
                model_inputs = torch.cat([latents] * 2)
            else:
                model_inputs = latents

            # predict the noise
            noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
            if guidance_scale > 1.:
                noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
                noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
            # compute the previous noise sample x_t-1 -> x_t
            latents, pred_x0 = self.next_step(noise_pred, t, latents)
            latents_list.append(latents)
            pred_x0_list.append(pred_x0)

        if return_intermediates:
            # return the intermediate latents during inversion
            # pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
            return latents, latents_list
        return latents, start_latents
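A short usage sketch for MasaCtrlPipeline: DDIM-invert an image and then reconstruct it from the recovered latents. The image path and target size are assumptions for illustration; the model id is one of the entries listed in app.py.

# Hypothetical usage of MasaCtrlPipeline (image path and 512x512 size are assumed).
import numpy as np
import torch
from PIL import Image
from diffusers import DDIMScheduler
from pdiff.pdiff_pipeline import MasaCtrlPipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "runwayml/stable-diffusion-v1-5"
pipe = MasaCtrlPipeline.from_pretrained(
    model_id, scheduler=DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
).to(device)

# Load an RGB image, resize to 512x512, map to [-1, 1], NCHW.
img = np.array(Image.open("face.jpg").convert("RGB").resize((512, 512)))
img = (torch.from_numpy(img).float() / 127.5 - 1.0).permute(2, 0, 1).unsqueeze(0).to(device)

prompt = "head"
# invert() returns (inverted noise latents, starting VAE latents) by default.
inverted, _ = pipe.invert(img, prompt, num_inference_steps=50, guidance_scale=7.5)
# __call__ returns a tensor in [0, 1] (latent2image with return_type="pt").
recon = pipe(prompt, latents=inverted, num_inference_steps=50, guidance_scale=7.5)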
requirements.txt
ADDED
@@ -0,0 +1,6 @@
diffusers==0.15.0
transformers
opencv-python
einops
omegaconf
pytorch_lightning
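A quick sanity check that the pinned dependency resolved as expected in the Space environment; a sketch only, nothing here is part of the commit.

# Print the installed versions of the two main dependencies.
import diffusers
import transformers

print("diffusers:", diffusers.__version__)        # requirements pin 0.15.0
print("transformers:", transformers.__version__)  # unpinned in requirements.txt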
style.css
ADDED
@@ -0,0 +1,3 @@
h1 {
    text-align: center;
}
utils/__init__.py
ADDED
File without changes
utils/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (141 Bytes)
utils/__pycache__/__init__.cpython-38.pyc
ADDED
Binary file (139 Bytes)
utils/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (139 Bytes)
utils/__pycache__/convert_from_ckpt.cpython-310.pyc
ADDED
Binary file (27.2 kB)
utils/__pycache__/convert_from_ckpt.cpython-38.pyc
ADDED
Binary file (28.2 kB)
utils/__pycache__/convert_from_ckpt.cpython-39.pyc
ADDED
Binary file (27.9 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-310.pyc
ADDED
Binary file (3.36 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-38.pyc
ADDED
Binary file (3.36 kB)
utils/__pycache__/convert_lora_safetensor_to_diffusers.cpython-39.pyc
ADDED
Binary file (3.33 kB)
utils/__pycache__/diffuser_utils.cpython-310.pyc
ADDED
Binary file (6.77 kB)
utils/__pycache__/diffuser_utils.cpython-38.pyc
ADDED
Binary file (6.8 kB)
utils/__pycache__/diffuser_utils.cpython-39.pyc
ADDED
Binary file (6.78 kB)
utils/__pycache__/free_lunch_utils.cpython-310.pyc
ADDED
Binary file (8.28 kB)
utils/__pycache__/free_lunch_utils.cpython-38.pyc
ADDED
Binary file (8.6 kB)
utils/__pycache__/free_lunch_utils.cpython-39.pyc
ADDED
Binary file (8.59 kB)
utils/__pycache__/masactrl_utils.cpython-310.pyc
ADDED
Binary file (6.2 kB)
utils/__pycache__/masactrl_utils.cpython-38.pyc
ADDED
Binary file (6.71 kB)
utils/__pycache__/masactrl_utils.cpython-39.pyc
ADDED
Binary file (6.69 kB)
utils/__pycache__/style_attn_control.cpython-310.pyc
ADDED
Binary file (8.73 kB)
utils/__pycache__/style_attn_control.cpython-38.pyc
ADDED
Binary file (8.67 kB)
utils/__pycache__/style_attn_control.cpython-39.pyc
ADDED
Binary file (8.75 kB)
utils/convert_from_ckpt.py
ADDED
@@ -0,0 +1,959 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023 The HuggingFace Inc. team.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
""" Conversion script for the Stable Diffusion checkpoints."""
|
16 |
+
|
17 |
+
import re
|
18 |
+
from io import BytesIO
|
19 |
+
from typing import Optional
|
20 |
+
|
21 |
+
import requests
|
22 |
+
import torch
|
23 |
+
from transformers import (
|
24 |
+
AutoFeatureExtractor,
|
25 |
+
BertTokenizerFast,
|
26 |
+
CLIPImageProcessor,
|
27 |
+
CLIPTextModel,
|
28 |
+
CLIPTextModelWithProjection,
|
29 |
+
CLIPTokenizer,
|
30 |
+
CLIPVisionConfig,
|
31 |
+
CLIPVisionModelWithProjection,
|
32 |
+
)
|
33 |
+
|
34 |
+
from diffusers.models import (
|
35 |
+
AutoencoderKL,
|
36 |
+
PriorTransformer,
|
37 |
+
UNet2DConditionModel,
|
38 |
+
)
|
39 |
+
from diffusers.schedulers import (
|
40 |
+
DDIMScheduler,
|
41 |
+
DDPMScheduler,
|
42 |
+
DPMSolverMultistepScheduler,
|
43 |
+
EulerAncestralDiscreteScheduler,
|
44 |
+
EulerDiscreteScheduler,
|
45 |
+
HeunDiscreteScheduler,
|
46 |
+
LMSDiscreteScheduler,
|
47 |
+
PNDMScheduler,
|
48 |
+
UnCLIPScheduler,
|
49 |
+
)
|
50 |
+
from diffusers.utils.import_utils import BACKENDS_MAPPING
|
51 |
+
|
52 |
+
|
53 |
+
def shave_segments(path, n_shave_prefix_segments=1):
|
54 |
+
"""
|
55 |
+
Removes segments. Positive values shave the first segments, negative shave the last segments.
|
56 |
+
"""
|
57 |
+
if n_shave_prefix_segments >= 0:
|
58 |
+
return ".".join(path.split(".")[n_shave_prefix_segments:])
|
59 |
+
else:
|
60 |
+
return ".".join(path.split(".")[:n_shave_prefix_segments])
|
61 |
+
|
62 |
+
|
63 |
+
def renew_resnet_paths(old_list, n_shave_prefix_segments=0):
|
64 |
+
"""
|
65 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
66 |
+
"""
|
67 |
+
mapping = []
|
68 |
+
for old_item in old_list:
|
69 |
+
new_item = old_item.replace("in_layers.0", "norm1")
|
70 |
+
new_item = new_item.replace("in_layers.2", "conv1")
|
71 |
+
|
72 |
+
new_item = new_item.replace("out_layers.0", "norm2")
|
73 |
+
new_item = new_item.replace("out_layers.3", "conv2")
|
74 |
+
|
75 |
+
new_item = new_item.replace("emb_layers.1", "time_emb_proj")
|
76 |
+
new_item = new_item.replace("skip_connection", "conv_shortcut")
|
77 |
+
|
78 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
79 |
+
|
80 |
+
mapping.append({"old": old_item, "new": new_item})
|
81 |
+
|
82 |
+
return mapping
|
83 |
+
|
84 |
+
|
85 |
+
def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
|
86 |
+
"""
|
87 |
+
Updates paths inside resnets to the new naming scheme (local renaming)
|
88 |
+
"""
|
89 |
+
mapping = []
|
90 |
+
for old_item in old_list:
|
91 |
+
new_item = old_item
|
92 |
+
|
93 |
+
new_item = new_item.replace("nin_shortcut", "conv_shortcut")
|
94 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
95 |
+
|
96 |
+
mapping.append({"old": old_item, "new": new_item})
|
97 |
+
|
98 |
+
return mapping
|
99 |
+
|
100 |
+
|
101 |
+
def renew_attention_paths(old_list, n_shave_prefix_segments=0):
|
102 |
+
"""
|
103 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
104 |
+
"""
|
105 |
+
mapping = []
|
106 |
+
for old_item in old_list:
|
107 |
+
new_item = old_item
|
108 |
+
|
109 |
+
# new_item = new_item.replace('norm.weight', 'group_norm.weight')
|
110 |
+
# new_item = new_item.replace('norm.bias', 'group_norm.bias')
|
111 |
+
|
112 |
+
# new_item = new_item.replace('proj_out.weight', 'proj_attn.weight')
|
113 |
+
# new_item = new_item.replace('proj_out.bias', 'proj_attn.bias')
|
114 |
+
|
115 |
+
# new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
116 |
+
|
117 |
+
mapping.append({"old": old_item, "new": new_item})
|
118 |
+
|
119 |
+
return mapping
|
120 |
+
|
121 |
+
|
122 |
+
def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
|
123 |
+
"""
|
124 |
+
Updates paths inside attentions to the new naming scheme (local renaming)
|
125 |
+
"""
|
126 |
+
mapping = []
|
127 |
+
for old_item in old_list:
|
128 |
+
new_item = old_item
|
129 |
+
|
130 |
+
new_item = new_item.replace("norm.weight", "group_norm.weight")
|
131 |
+
new_item = new_item.replace("norm.bias", "group_norm.bias")
|
132 |
+
|
133 |
+
new_item = new_item.replace("q.weight", "query.weight")
|
134 |
+
new_item = new_item.replace("q.bias", "query.bias")
|
135 |
+
|
136 |
+
new_item = new_item.replace("k.weight", "key.weight")
|
137 |
+
new_item = new_item.replace("k.bias", "key.bias")
|
138 |
+
|
139 |
+
new_item = new_item.replace("v.weight", "value.weight")
|
140 |
+
new_item = new_item.replace("v.bias", "value.bias")
|
141 |
+
|
142 |
+
new_item = new_item.replace("proj_out.weight", "proj_attn.weight")
|
143 |
+
new_item = new_item.replace("proj_out.bias", "proj_attn.bias")
|
144 |
+
|
145 |
+
new_item = shave_segments(new_item, n_shave_prefix_segments=n_shave_prefix_segments)
|
146 |
+
|
147 |
+
mapping.append({"old": old_item, "new": new_item})
|
148 |
+
|
149 |
+
return mapping
|
150 |
+
|
151 |
+
|
152 |
+
def assign_to_checkpoint(
|
153 |
+
paths, checkpoint, old_checkpoint, attention_paths_to_split=None, additional_replacements=None, config=None
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
|
157 |
+
attention layers, and takes into account additional replacements that may arise.
|
158 |
+
|
159 |
+
Assigns the weights to the new checkpoint.
|
160 |
+
"""
|
161 |
+
assert isinstance(paths, list), "Paths should be a list of dicts containing 'old' and 'new' keys."
|
162 |
+
|
163 |
+
# Splits the attention layers into three variables.
|
164 |
+
if attention_paths_to_split is not None:
|
165 |
+
for path, path_map in attention_paths_to_split.items():
|
166 |
+
old_tensor = old_checkpoint[path]
|
167 |
+
channels = old_tensor.shape[0] // 3
|
168 |
+
|
169 |
+
target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
|
170 |
+
|
171 |
+
num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
|
172 |
+
|
173 |
+
old_tensor = old_tensor.reshape((num_heads, 3 * channels // num_heads) + old_tensor.shape[1:])
|
174 |
+
query, key, value = old_tensor.split(channels // num_heads, dim=1)
|
175 |
+
|
176 |
+
checkpoint[path_map["query"]] = query.reshape(target_shape)
|
177 |
+
checkpoint[path_map["key"]] = key.reshape(target_shape)
|
178 |
+
checkpoint[path_map["value"]] = value.reshape(target_shape)
|
179 |
+
|
180 |
+
for path in paths:
|
181 |
+
new_path = path["new"]
|
182 |
+
|
183 |
+
# These have already been assigned
|
184 |
+
if attention_paths_to_split is not None and new_path in attention_paths_to_split:
|
185 |
+
continue
|
186 |
+
|
187 |
+
# Global renaming happens here
|
188 |
+
new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
|
189 |
+
new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
|
190 |
+
new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
|
191 |
+
|
192 |
+
if additional_replacements is not None:
|
193 |
+
for replacement in additional_replacements:
|
194 |
+
new_path = new_path.replace(replacement["old"], replacement["new"])
|
195 |
+
|
196 |
+
# proj_attn.weight has to be converted from conv 1D to linear
|
197 |
+
if "proj_attn.weight" in new_path:
|
198 |
+
checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
|
199 |
+
else:
|
200 |
+
checkpoint[new_path] = old_checkpoint[path["old"]]
|
201 |
+
|
202 |
+
|
203 |
+
def conv_attn_to_linear(checkpoint):
|
204 |
+
keys = list(checkpoint.keys())
|
205 |
+
attn_keys = ["query.weight", "key.weight", "value.weight"]
|
206 |
+
for key in keys:
|
207 |
+
if ".".join(key.split(".")[-2:]) in attn_keys:
|
208 |
+
if checkpoint[key].ndim > 2:
|
209 |
+
checkpoint[key] = checkpoint[key][:, :, 0, 0]
|
210 |
+
elif "proj_attn.weight" in key:
|
211 |
+
if checkpoint[key].ndim > 2:
|
212 |
+
checkpoint[key] = checkpoint[key][:, :, 0]
|
213 |
+
|
214 |
+
|
215 |
+
def create_unet_diffusers_config(original_config, image_size: int, controlnet=False):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    if controlnet:
        unet_params = original_config.model.params.control_stage_config.params
    else:
        unet_params = original_config.model.params.unet_config.params

    vae_params = original_config.model.params.first_stage_config.params.ddconfig

    block_out_channels = [unet_params.model_channels * mult for mult in unet_params.channel_mult]

    down_block_types = []
    resolution = 1
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnDownBlock2D" if resolution in unet_params.attention_resolutions else "DownBlock2D"
        down_block_types.append(block_type)
        if i != len(block_out_channels) - 1:
            resolution *= 2

    up_block_types = []
    for i in range(len(block_out_channels)):
        block_type = "CrossAttnUpBlock2D" if resolution in unet_params.attention_resolutions else "UpBlock2D"
        up_block_types.append(block_type)
        resolution //= 2

    vae_scale_factor = 2 ** (len(vae_params.ch_mult) - 1)

    head_dim = unet_params.num_heads if "num_heads" in unet_params else None
    use_linear_projection = (
        unet_params.use_linear_in_transformer if "use_linear_in_transformer" in unet_params else False
    )
    if use_linear_projection:
        # stable diffusion 2-base-512 and 2-768
        if head_dim is None:
            head_dim = [5, 10, 20, 20]

    class_embed_type = None
    projection_class_embeddings_input_dim = None

    if "num_classes" in unet_params:
        if unet_params.num_classes == "sequential":
            class_embed_type = "projection"
            assert "adm_in_channels" in unet_params
            projection_class_embeddings_input_dim = unet_params.adm_in_channels
        else:
            raise NotImplementedError(f"Unknown conditional unet num_classes config: {unet_params.num_classes}")

    config = {
        "sample_size": image_size // vae_scale_factor,
        "in_channels": unet_params.in_channels,
        "down_block_types": tuple(down_block_types),
        "block_out_channels": tuple(block_out_channels),
        "layers_per_block": unet_params.num_res_blocks,
        "cross_attention_dim": unet_params.context_dim,
        "attention_head_dim": head_dim,
        "use_linear_projection": use_linear_projection,
        "class_embed_type": class_embed_type,
        "projection_class_embeddings_input_dim": projection_class_embeddings_input_dim,
    }

    if not controlnet:
        config["out_channels"] = unet_params.out_channels
        config["up_block_types"] = tuple(up_block_types)

    return config

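# For reference (not in the original file): with a standard SD 1.x `v1-inference.yaml`
# (model_channels=320, channel_mult=[1, 2, 4, 4], attention_resolutions=[4, 2, 1],
# context_dim=768, num_heads=8) and image_size=512, this function typically yields:
#
#   {
#       "sample_size": 64,
#       "in_channels": 4,
#       "down_block_types": ("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D"),
#       "block_out_channels": (320, 640, 1280, 1280),
#       "layers_per_block": 2,
#       "cross_attention_dim": 768,
#       "attention_head_dim": 8,
#       "use_linear_projection": False,
#       "class_embed_type": None,
#       "projection_class_embeddings_input_dim": None,
#       "out_channels": 4,
#       "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"),
#   }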
283 |
+
|
284 |
+
def create_vae_diffusers_config(original_config, image_size: int):
    """
    Creates a config for the diffusers based on the config of the LDM model.
    """
    vae_params = original_config.model.params.first_stage_config.params.ddconfig
    _ = original_config.model.params.first_stage_config.params.embed_dim

    block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
    down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
    up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)

    config = {
        "sample_size": image_size,
        "in_channels": vae_params.in_channels,
        "out_channels": vae_params.out_ch,
        "down_block_types": tuple(down_block_types),
        "up_block_types": tuple(up_block_types),
        "block_out_channels": tuple(block_out_channels),
        "latent_channels": vae_params.z_channels,
        "layers_per_block": vae_params.num_res_blocks,
    }
    return config
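# For reference (not in the original file): with the stock SD first-stage `ddconfig`
# (ch=128, ch_mult=[1, 2, 4, 4], z_channels=4, num_res_blocks=2) and image_size=512,
# this typically yields block_out_channels=(128, 256, 512, 512), latent_channels=4,
# layers_per_block=2, and four DownEncoderBlock2D / UpDecoderBlock2D blocks.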
|
306 |
+
|
307 |
+
|
308 |
+
def create_diffusers_schedular(original_config):
    schedular = DDIMScheduler(
        num_train_timesteps=original_config.model.params.timesteps,
        beta_start=original_config.model.params.linear_start,
        beta_end=original_config.model.params.linear_end,
        beta_schedule="scaled_linear",
    )
    return schedular


def create_ldm_bert_config(original_config):
    # was `original_config.model.parms...`, which would raise an AttributeError; the config object exposes `model.params`
    bert_params = original_config.model.params.cond_stage_config.params
    config = LDMBertConfig(
        d_model=bert_params.n_embed,
        encoder_layers=bert_params.n_layer,
        encoder_ffn_dim=bert_params.n_embed * 4,
    )
    return config
|
326 |
+
|
327 |
+
|
328 |
+
def convert_ldm_unet_checkpoint(checkpoint, config, path=None, extract_ema=False, controlnet=False):
|
329 |
+
"""
|
330 |
+
Takes a state dict and a config, and returns a converted checkpoint.
|
331 |
+
"""
|
332 |
+
|
333 |
+
# extract state_dict for UNet
|
334 |
+
unet_state_dict = {}
|
335 |
+
keys = list(checkpoint.keys())
|
336 |
+
|
337 |
+
if controlnet:
|
338 |
+
unet_key = "control_model."
|
339 |
+
else:
|
340 |
+
unet_key = "model.diffusion_model."
|
341 |
+
|
342 |
+
# at least a 100 parameters have to start with `model_ema` in order for the checkpoint to be EMA
|
343 |
+
if sum(k.startswith("model_ema") for k in keys) > 100 and extract_ema:
|
344 |
+
print(f"Checkpoint {path} has both EMA and non-EMA weights.")
|
345 |
+
print(
|
346 |
+
"In this conversion only the EMA weights are extracted. If you want to instead extract the non-EMA"
|
347 |
+
" weights (useful to continue fine-tuning), please make sure to remove the `--extract_ema` flag."
|
348 |
+
)
|
349 |
+
for key in keys:
|
350 |
+
if key.startswith("model.diffusion_model"):
|
351 |
+
flat_ema_key = "model_ema." + "".join(key.split(".")[1:])
|
352 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(flat_ema_key)
|
353 |
+
else:
|
354 |
+
if sum(k.startswith("model_ema") for k in keys) > 100:
|
355 |
+
print(
|
356 |
+
"In this conversion only the non-EMA weights are extracted. If you want to instead extract the EMA"
|
357 |
+
" weights (usually better for inference), please make sure to add the `--extract_ema` flag."
|
358 |
+
)
|
359 |
+
|
360 |
+
for key in keys:
|
361 |
+
if key.startswith(unet_key):
|
362 |
+
unet_state_dict[key.replace(unet_key, "")] = checkpoint.pop(key)
|
363 |
+
|
364 |
+
new_checkpoint = {}
|
365 |
+
|
366 |
+
new_checkpoint["time_embedding.linear_1.weight"] = unet_state_dict["time_embed.0.weight"]
|
367 |
+
new_checkpoint["time_embedding.linear_1.bias"] = unet_state_dict["time_embed.0.bias"]
|
368 |
+
new_checkpoint["time_embedding.linear_2.weight"] = unet_state_dict["time_embed.2.weight"]
|
369 |
+
new_checkpoint["time_embedding.linear_2.bias"] = unet_state_dict["time_embed.2.bias"]
|
370 |
+
|
371 |
+
if config["class_embed_type"] is None:
|
372 |
+
# No parameters to port
|
373 |
+
...
|
374 |
+
elif config["class_embed_type"] == "timestep" or config["class_embed_type"] == "projection":
|
375 |
+
new_checkpoint["class_embedding.linear_1.weight"] = unet_state_dict["label_emb.0.0.weight"]
|
376 |
+
new_checkpoint["class_embedding.linear_1.bias"] = unet_state_dict["label_emb.0.0.bias"]
|
377 |
+
new_checkpoint["class_embedding.linear_2.weight"] = unet_state_dict["label_emb.0.2.weight"]
|
378 |
+
new_checkpoint["class_embedding.linear_2.bias"] = unet_state_dict["label_emb.0.2.bias"]
|
379 |
+
else:
|
380 |
+
raise NotImplementedError(f"Not implemented `class_embed_type`: {config['class_embed_type']}")
|
381 |
+
|
382 |
+
new_checkpoint["conv_in.weight"] = unet_state_dict["input_blocks.0.0.weight"]
|
383 |
+
new_checkpoint["conv_in.bias"] = unet_state_dict["input_blocks.0.0.bias"]
|
384 |
+
|
385 |
+
if not controlnet:
|
386 |
+
new_checkpoint["conv_norm_out.weight"] = unet_state_dict["out.0.weight"]
|
387 |
+
new_checkpoint["conv_norm_out.bias"] = unet_state_dict["out.0.bias"]
|
388 |
+
new_checkpoint["conv_out.weight"] = unet_state_dict["out.2.weight"]
|
389 |
+
new_checkpoint["conv_out.bias"] = unet_state_dict["out.2.bias"]
|
390 |
+
|
391 |
+
# Retrieves the keys for the input blocks only
|
392 |
+
num_input_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "input_blocks" in layer})
|
393 |
+
input_blocks = {
|
394 |
+
layer_id: [key for key in unet_state_dict if f"input_blocks.{layer_id}" in key]
|
395 |
+
for layer_id in range(num_input_blocks)
|
396 |
+
}
|
397 |
+
|
398 |
+
# Retrieves the keys for the middle blocks only
|
399 |
+
num_middle_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "middle_block" in layer})
|
400 |
+
middle_blocks = {
|
401 |
+
layer_id: [key for key in unet_state_dict if f"middle_block.{layer_id}" in key]
|
402 |
+
for layer_id in range(num_middle_blocks)
|
403 |
+
}
|
404 |
+
|
405 |
+
# Retrieves the keys for the output blocks only
|
406 |
+
num_output_blocks = len({".".join(layer.split(".")[:2]) for layer in unet_state_dict if "output_blocks" in layer})
|
407 |
+
output_blocks = {
|
408 |
+
layer_id: [key for key in unet_state_dict if f"output_blocks.{layer_id}" in key]
|
409 |
+
for layer_id in range(num_output_blocks)
|
410 |
+
}
|
411 |
+
|
412 |
+
for i in range(1, num_input_blocks):
|
413 |
+
block_id = (i - 1) // (config["layers_per_block"] + 1)
|
414 |
+
layer_in_block_id = (i - 1) % (config["layers_per_block"] + 1)
|
415 |
+
|
416 |
+
resnets = [
|
417 |
+
key for key in input_blocks[i] if f"input_blocks.{i}.0" in key and f"input_blocks.{i}.0.op" not in key
|
418 |
+
]
|
419 |
+
attentions = [key for key in input_blocks[i] if f"input_blocks.{i}.1" in key]
|
420 |
+
|
421 |
+
if f"input_blocks.{i}.0.op.weight" in unet_state_dict:
|
422 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.weight"] = unet_state_dict.pop(
|
423 |
+
f"input_blocks.{i}.0.op.weight"
|
424 |
+
)
|
425 |
+
new_checkpoint[f"down_blocks.{block_id}.downsamplers.0.conv.bias"] = unet_state_dict.pop(
|
426 |
+
f"input_blocks.{i}.0.op.bias"
|
427 |
+
)
|
428 |
+
|
429 |
+
paths = renew_resnet_paths(resnets)
|
430 |
+
meta_path = {"old": f"input_blocks.{i}.0", "new": f"down_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
431 |
+
assign_to_checkpoint(
|
432 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
433 |
+
)
|
434 |
+
|
435 |
+
if len(attentions):
|
436 |
+
paths = renew_attention_paths(attentions)
|
437 |
+
meta_path = {"old": f"input_blocks.{i}.1", "new": f"down_blocks.{block_id}.attentions.{layer_in_block_id}"}
|
438 |
+
assign_to_checkpoint(
|
439 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
440 |
+
)
|
441 |
+
|
442 |
+
resnet_0 = middle_blocks[0]
|
443 |
+
attentions = middle_blocks[1]
|
444 |
+
resnet_1 = middle_blocks[2]
|
445 |
+
|
446 |
+
resnet_0_paths = renew_resnet_paths(resnet_0)
|
447 |
+
assign_to_checkpoint(resnet_0_paths, new_checkpoint, unet_state_dict, config=config)
|
448 |
+
|
449 |
+
resnet_1_paths = renew_resnet_paths(resnet_1)
|
450 |
+
assign_to_checkpoint(resnet_1_paths, new_checkpoint, unet_state_dict, config=config)
|
451 |
+
|
452 |
+
attentions_paths = renew_attention_paths(attentions)
|
453 |
+
meta_path = {"old": "middle_block.1", "new": "mid_block.attentions.0"}
|
454 |
+
assign_to_checkpoint(
|
455 |
+
attentions_paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
456 |
+
)
|
457 |
+
|
458 |
+
for i in range(num_output_blocks):
|
459 |
+
block_id = i // (config["layers_per_block"] + 1)
|
460 |
+
layer_in_block_id = i % (config["layers_per_block"] + 1)
|
461 |
+
output_block_layers = [shave_segments(name, 2) for name in output_blocks[i]]
|
462 |
+
output_block_list = {}
|
463 |
+
|
464 |
+
for layer in output_block_layers:
|
465 |
+
layer_id, layer_name = layer.split(".")[0], shave_segments(layer, 1)
|
466 |
+
if layer_id in output_block_list:
|
467 |
+
output_block_list[layer_id].append(layer_name)
|
468 |
+
else:
|
469 |
+
output_block_list[layer_id] = [layer_name]
|
470 |
+
|
471 |
+
if len(output_block_list) > 1:
|
472 |
+
resnets = [key for key in output_blocks[i] if f"output_blocks.{i}.0" in key]
|
473 |
+
attentions = [key for key in output_blocks[i] if f"output_blocks.{i}.1" in key]
|
474 |
+
|
475 |
+
resnet_0_paths = renew_resnet_paths(resnets)
|
476 |
+
paths = renew_resnet_paths(resnets)
|
477 |
+
|
478 |
+
meta_path = {"old": f"output_blocks.{i}.0", "new": f"up_blocks.{block_id}.resnets.{layer_in_block_id}"}
|
479 |
+
assign_to_checkpoint(
|
480 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
481 |
+
)
|
482 |
+
|
483 |
+
output_block_list = {k: sorted(v) for k, v in output_block_list.items()}
|
484 |
+
if ["conv.bias", "conv.weight"] in output_block_list.values():
|
485 |
+
index = list(output_block_list.values()).index(["conv.bias", "conv.weight"])
|
486 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.weight"] = unet_state_dict[
|
487 |
+
f"output_blocks.{i}.{index}.conv.weight"
|
488 |
+
]
|
489 |
+
new_checkpoint[f"up_blocks.{block_id}.upsamplers.0.conv.bias"] = unet_state_dict[
|
490 |
+
f"output_blocks.{i}.{index}.conv.bias"
|
491 |
+
]
|
492 |
+
|
493 |
+
# Clear attentions as they have been attributed above.
|
494 |
+
if len(attentions) == 2:
|
495 |
+
attentions = []
|
496 |
+
|
497 |
+
if len(attentions):
|
498 |
+
paths = renew_attention_paths(attentions)
|
499 |
+
meta_path = {
|
500 |
+
"old": f"output_blocks.{i}.1",
|
501 |
+
"new": f"up_blocks.{block_id}.attentions.{layer_in_block_id}",
|
502 |
+
}
|
503 |
+
assign_to_checkpoint(
|
504 |
+
paths, new_checkpoint, unet_state_dict, additional_replacements=[meta_path], config=config
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
resnet_0_paths = renew_resnet_paths(output_block_layers, n_shave_prefix_segments=1)
|
508 |
+
for path in resnet_0_paths:
|
509 |
+
old_path = ".".join(["output_blocks", str(i), path["old"]])
|
510 |
+
new_path = ".".join(["up_blocks", str(block_id), "resnets", str(layer_in_block_id), path["new"]])
|
511 |
+
|
512 |
+
new_checkpoint[new_path] = unet_state_dict[old_path]
|
513 |
+
|
514 |
+
if controlnet:
|
515 |
+
# conditioning embedding
|
516 |
+
|
517 |
+
orig_index = 0
|
518 |
+
|
519 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.weight"] = unet_state_dict.pop(
|
520 |
+
f"input_hint_block.{orig_index}.weight"
|
521 |
+
)
|
522 |
+
new_checkpoint["controlnet_cond_embedding.conv_in.bias"] = unet_state_dict.pop(
|
523 |
+
f"input_hint_block.{orig_index}.bias"
|
524 |
+
)
|
525 |
+
|
526 |
+
orig_index += 2
|
527 |
+
|
528 |
+
diffusers_index = 0
|
529 |
+
|
530 |
+
while diffusers_index < 6:
|
531 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.weight"] = unet_state_dict.pop(
|
532 |
+
f"input_hint_block.{orig_index}.weight"
|
533 |
+
)
|
534 |
+
new_checkpoint[f"controlnet_cond_embedding.blocks.{diffusers_index}.bias"] = unet_state_dict.pop(
|
535 |
+
f"input_hint_block.{orig_index}.bias"
|
536 |
+
)
|
537 |
+
diffusers_index += 1
|
538 |
+
orig_index += 2
|
539 |
+
|
540 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.weight"] = unet_state_dict.pop(
|
541 |
+
f"input_hint_block.{orig_index}.weight"
|
542 |
+
)
|
543 |
+
new_checkpoint["controlnet_cond_embedding.conv_out.bias"] = unet_state_dict.pop(
|
544 |
+
f"input_hint_block.{orig_index}.bias"
|
545 |
+
)
|
546 |
+
|
547 |
+
# down blocks
|
548 |
+
for i in range(num_input_blocks):
|
549 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.weight"] = unet_state_dict.pop(f"zero_convs.{i}.0.weight")
|
550 |
+
new_checkpoint[f"controlnet_down_blocks.{i}.bias"] = unet_state_dict.pop(f"zero_convs.{i}.0.bias")
|
551 |
+
|
552 |
+
# mid block
|
553 |
+
new_checkpoint["controlnet_mid_block.weight"] = unet_state_dict.pop("middle_block_out.0.weight")
|
554 |
+
new_checkpoint["controlnet_mid_block.bias"] = unet_state_dict.pop("middle_block_out.0.bias")
|
555 |
+
|
556 |
+
return new_checkpoint
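# Illustrative sketch (not part of the original file): how the UNet conversion pieces are
# typically combined; `original_config`, `state_dict`, and `ckpt_path` are hypothetical names,
# and UNet2DConditionModel is assumed to come from `diffusers`.
#
#   unet_config = create_unet_diffusers_config(original_config, image_size=512)
#   converted_unet = convert_ldm_unet_checkpoint(state_dict, unet_config, path=ckpt_path, extract_ema=False)
#   unet = UNet2DConditionModel(**unet_config)
#   unet.load_state_dict(converted_unet)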
|
557 |
+
|
558 |
+
|
559 |
+
def convert_ldm_vae_checkpoint(checkpoint, config):
|
560 |
+
# extract state dict for VAE
|
561 |
+
vae_state_dict = {}
|
562 |
+
vae_key = "first_stage_model."
|
563 |
+
keys = list(checkpoint.keys())
|
564 |
+
for key in keys:
|
565 |
+
if key.startswith(vae_key):
|
566 |
+
vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
|
567 |
+
|
568 |
+
new_checkpoint = {}
|
569 |
+
|
570 |
+
new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
|
571 |
+
new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
|
572 |
+
new_checkpoint["encoder.conv_out.weight"] = vae_state_dict["encoder.conv_out.weight"]
|
573 |
+
new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
|
574 |
+
new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict["encoder.norm_out.weight"]
|
575 |
+
new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict["encoder.norm_out.bias"]
|
576 |
+
|
577 |
+
new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
|
578 |
+
new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
|
579 |
+
new_checkpoint["decoder.conv_out.weight"] = vae_state_dict["decoder.conv_out.weight"]
|
580 |
+
new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
|
581 |
+
new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict["decoder.norm_out.weight"]
|
582 |
+
new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict["decoder.norm_out.bias"]
|
583 |
+
|
584 |
+
new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
|
585 |
+
new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
|
586 |
+
new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
|
587 |
+
new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
|
588 |
+
|
589 |
+
# Retrieves the keys for the encoder down blocks only
|
590 |
+
num_down_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "encoder.down" in layer})
|
591 |
+
down_blocks = {
|
592 |
+
layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key] for layer_id in range(num_down_blocks)
|
593 |
+
}
|
594 |
+
|
595 |
+
# Retrieves the keys for the decoder up blocks only
|
596 |
+
num_up_blocks = len({".".join(layer.split(".")[:3]) for layer in vae_state_dict if "decoder.up" in layer})
|
597 |
+
up_blocks = {
|
598 |
+
layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key] for layer_id in range(num_up_blocks)
|
599 |
+
}
|
600 |
+
|
601 |
+
for i in range(num_down_blocks):
|
602 |
+
resnets = [key for key in down_blocks[i] if f"down.{i}" in key and f"down.{i}.downsample" not in key]
|
603 |
+
|
604 |
+
if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
|
605 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"] = vae_state_dict.pop(
|
606 |
+
f"encoder.down.{i}.downsample.conv.weight"
|
607 |
+
)
|
608 |
+
new_checkpoint[f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"] = vae_state_dict.pop(
|
609 |
+
f"encoder.down.{i}.downsample.conv.bias"
|
610 |
+
)
|
611 |
+
|
612 |
+
paths = renew_vae_resnet_paths(resnets)
|
613 |
+
meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
|
614 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
615 |
+
|
616 |
+
mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
|
617 |
+
num_mid_res_blocks = 2
|
618 |
+
for i in range(1, num_mid_res_blocks + 1):
|
619 |
+
resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
|
620 |
+
|
621 |
+
paths = renew_vae_resnet_paths(resnets)
|
622 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
623 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
624 |
+
|
625 |
+
mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
|
626 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
627 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
628 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
629 |
+
conv_attn_to_linear(new_checkpoint)
|
630 |
+
|
631 |
+
for i in range(num_up_blocks):
|
632 |
+
block_id = num_up_blocks - 1 - i
|
633 |
+
resnets = [
|
634 |
+
key for key in up_blocks[block_id] if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
|
635 |
+
]
|
636 |
+
|
637 |
+
if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
|
638 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"] = vae_state_dict[
|
639 |
+
f"decoder.up.{block_id}.upsample.conv.weight"
|
640 |
+
]
|
641 |
+
new_checkpoint[f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"] = vae_state_dict[
|
642 |
+
f"decoder.up.{block_id}.upsample.conv.bias"
|
643 |
+
]
|
644 |
+
|
645 |
+
paths = renew_vae_resnet_paths(resnets)
|
646 |
+
meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
|
647 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
648 |
+
|
649 |
+
mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
|
650 |
+
num_mid_res_blocks = 2
|
651 |
+
for i in range(1, num_mid_res_blocks + 1):
|
652 |
+
resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
|
653 |
+
|
654 |
+
paths = renew_vae_resnet_paths(resnets)
|
655 |
+
meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
|
656 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
657 |
+
|
658 |
+
mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
|
659 |
+
paths = renew_vae_attention_paths(mid_attentions)
|
660 |
+
meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
|
661 |
+
assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path], config=config)
|
662 |
+
conv_attn_to_linear(new_checkpoint)
|
663 |
+
return new_checkpoint
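# Illustrative sketch (not part of the original file): the VAE counterpart of the flow above;
# `original_config` and `state_dict` are hypothetical names, AutoencoderKL is assumed from `diffusers`.
#
#   vae_config = create_vae_diffusers_config(original_config, image_size=512)
#   converted_vae = convert_ldm_vae_checkpoint(state_dict, vae_config)
#   vae = AutoencoderKL(**vae_config)
#   vae.load_state_dict(converted_vae)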
|
664 |
+
|
665 |
+
|
666 |
+
def convert_ldm_bert_checkpoint(checkpoint, config):
|
667 |
+
def _copy_attn_layer(hf_attn_layer, pt_attn_layer):
|
668 |
+
hf_attn_layer.q_proj.weight.data = pt_attn_layer.to_q.weight
|
669 |
+
hf_attn_layer.k_proj.weight.data = pt_attn_layer.to_k.weight
|
670 |
+
hf_attn_layer.v_proj.weight.data = pt_attn_layer.to_v.weight
|
671 |
+
|
672 |
+
hf_attn_layer.out_proj.weight = pt_attn_layer.to_out.weight
|
673 |
+
hf_attn_layer.out_proj.bias = pt_attn_layer.to_out.bias
|
674 |
+
|
675 |
+
def _copy_linear(hf_linear, pt_linear):
|
676 |
+
hf_linear.weight = pt_linear.weight
|
677 |
+
hf_linear.bias = pt_linear.bias
|
678 |
+
|
679 |
+
def _copy_layer(hf_layer, pt_layer):
|
680 |
+
# copy layer norms
|
681 |
+
_copy_linear(hf_layer.self_attn_layer_norm, pt_layer[0][0])
|
682 |
+
_copy_linear(hf_layer.final_layer_norm, pt_layer[1][0])
|
683 |
+
|
684 |
+
# copy attn
|
685 |
+
_copy_attn_layer(hf_layer.self_attn, pt_layer[0][1])
|
686 |
+
|
687 |
+
# copy MLP
|
688 |
+
pt_mlp = pt_layer[1][1]
|
689 |
+
_copy_linear(hf_layer.fc1, pt_mlp.net[0][0])
|
690 |
+
_copy_linear(hf_layer.fc2, pt_mlp.net[2])
|
691 |
+
|
692 |
+
def _copy_layers(hf_layers, pt_layers):
|
693 |
+
for i, hf_layer in enumerate(hf_layers):
|
694 |
+
if i != 0:
|
695 |
+
i += i
|
696 |
+
pt_layer = pt_layers[i : i + 2]
|
697 |
+
_copy_layer(hf_layer, pt_layer)
|
698 |
+
|
699 |
+
hf_model = LDMBertModel(config).eval()
|
700 |
+
|
701 |
+
# copy embeds
|
702 |
+
hf_model.model.embed_tokens.weight = checkpoint.transformer.token_emb.weight
|
703 |
+
hf_model.model.embed_positions.weight.data = checkpoint.transformer.pos_emb.emb.weight
|
704 |
+
|
705 |
+
# copy layer norm
|
706 |
+
_copy_linear(hf_model.model.layer_norm, checkpoint.transformer.norm)
|
707 |
+
|
708 |
+
# copy hidden layers
|
709 |
+
_copy_layers(hf_model.model.layers, checkpoint.transformer.attn_layers.layers)
|
710 |
+
|
711 |
+
_copy_linear(hf_model.to_logits, checkpoint.transformer.to_logits)
|
712 |
+
|
713 |
+
return hf_model
|
714 |
+
|
715 |
+
|
716 |
+
def convert_ldm_clip_checkpoint(checkpoint):
    text_model = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    keys = list(checkpoint.keys())

    text_model_dict = {}

    for key in keys:
        if key.startswith("cond_stage_model.transformer"):
            text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]

    text_model.load_state_dict(text_model_dict)

    return text_model


textenc_conversion_lst = [
    ("cond_stage_model.model.positional_embedding", "text_model.embeddings.position_embedding.weight"),
    ("cond_stage_model.model.token_embedding.weight", "text_model.embeddings.token_embedding.weight"),
    ("cond_stage_model.model.ln_final.weight", "text_model.final_layer_norm.weight"),
    ("cond_stage_model.model.ln_final.bias", "text_model.final_layer_norm.bias"),
]
textenc_conversion_map = {x[0]: x[1] for x in textenc_conversion_lst}

textenc_transformer_conversion_lst = [
    # (stable-diffusion, HF Diffusers)
    ("resblocks.", "text_model.encoder.layers."),
    ("ln_1", "layer_norm1"),
    ("ln_2", "layer_norm2"),
    (".c_fc.", ".fc1."),
    (".c_proj.", ".fc2."),
    (".attn", ".self_attn"),
    ("ln_final.", "transformer.text_model.final_layer_norm."),
    ("token_embedding.weight", "transformer.text_model.embeddings.token_embedding.weight"),
    ("positional_embedding", "transformer.text_model.embeddings.position_embedding.weight"),
]
protected = {re.escape(x[0]): x[1] for x in textenc_transformer_conversion_lst}
textenc_pattern = re.compile("|".join(protected.keys()))
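# For reference (not in the original file): `textenc_pattern` rewrites open-CLIP style keys into the
# HF layout in one pass, e.g. "resblocks.0.ln_1.weight" -> "text_model.encoder.layers.0.layer_norm1.weight":
#
#   renamed = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], "resblocks.0.ln_1.weight")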
|
753 |
+
|
754 |
+
|
755 |
+
def convert_paint_by_example_checkpoint(checkpoint):
|
756 |
+
config = CLIPVisionConfig.from_pretrained("openai/clip-vit-large-patch14")
|
757 |
+
model = PaintByExampleImageEncoder(config)
|
758 |
+
|
759 |
+
keys = list(checkpoint.keys())
|
760 |
+
|
761 |
+
text_model_dict = {}
|
762 |
+
|
763 |
+
for key in keys:
|
764 |
+
if key.startswith("cond_stage_model.transformer"):
|
765 |
+
text_model_dict[key[len("cond_stage_model.transformer.") :]] = checkpoint[key]
|
766 |
+
|
767 |
+
# load clip vision
|
768 |
+
model.model.load_state_dict(text_model_dict)
|
769 |
+
|
770 |
+
# load mapper
|
771 |
+
keys_mapper = {
|
772 |
+
k[len("cond_stage_model.mapper.res") :]: v
|
773 |
+
for k, v in checkpoint.items()
|
774 |
+
if k.startswith("cond_stage_model.mapper")
|
775 |
+
}
|
776 |
+
|
777 |
+
MAPPING = {
|
778 |
+
"attn.c_qkv": ["attn1.to_q", "attn1.to_k", "attn1.to_v"],
|
779 |
+
"attn.c_proj": ["attn1.to_out.0"],
|
780 |
+
"ln_1": ["norm1"],
|
781 |
+
"ln_2": ["norm3"],
|
782 |
+
"mlp.c_fc": ["ff.net.0.proj"],
|
783 |
+
"mlp.c_proj": ["ff.net.2"],
|
784 |
+
}
|
785 |
+
|
786 |
+
mapped_weights = {}
|
787 |
+
for key, value in keys_mapper.items():
|
788 |
+
prefix = key[: len("blocks.i")]
|
789 |
+
suffix = key.split(prefix)[-1].split(".")[-1]
|
790 |
+
name = key.split(prefix)[-1].split(suffix)[0][1:-1]
|
791 |
+
mapped_names = MAPPING[name]
|
792 |
+
|
793 |
+
num_splits = len(mapped_names)
|
794 |
+
for i, mapped_name in enumerate(mapped_names):
|
795 |
+
new_name = ".".join([prefix, mapped_name, suffix])
|
796 |
+
shape = value.shape[0] // num_splits
|
797 |
+
mapped_weights[new_name] = value[i * shape : (i + 1) * shape]
|
798 |
+
|
799 |
+
model.mapper.load_state_dict(mapped_weights)
|
800 |
+
|
801 |
+
# load final layer norm
|
802 |
+
model.final_layer_norm.load_state_dict(
|
803 |
+
{
|
804 |
+
"bias": checkpoint["cond_stage_model.final_ln.bias"],
|
805 |
+
"weight": checkpoint["cond_stage_model.final_ln.weight"],
|
806 |
+
}
|
807 |
+
)
|
808 |
+
|
809 |
+
# load final proj
|
810 |
+
model.proj_out.load_state_dict(
|
811 |
+
{
|
812 |
+
"bias": checkpoint["proj_out.bias"],
|
813 |
+
"weight": checkpoint["proj_out.weight"],
|
814 |
+
}
|
815 |
+
)
|
816 |
+
|
817 |
+
# load uncond vector
|
818 |
+
model.uncond_vector.data = torch.nn.Parameter(checkpoint["learnable_vector"])
|
819 |
+
return model
|
820 |
+
|
821 |
+
|
822 |
+
def convert_open_clip_checkpoint(checkpoint):
|
823 |
+
text_model = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2", subfolder="text_encoder")
|
824 |
+
|
825 |
+
keys = list(checkpoint.keys())
|
826 |
+
|
827 |
+
text_model_dict = {}
|
828 |
+
|
829 |
+
if "cond_stage_model.model.text_projection" in checkpoint:
|
830 |
+
d_model = int(checkpoint["cond_stage_model.model.text_projection"].shape[0])
|
831 |
+
else:
|
832 |
+
d_model = 1024
|
833 |
+
|
834 |
+
text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
|
835 |
+
|
836 |
+
for key in keys:
|
837 |
+
if "resblocks.23" in key: # Diffusers drops the final layer and only uses the penultimate layer
|
838 |
+
continue
|
839 |
+
if key in textenc_conversion_map:
|
840 |
+
text_model_dict[textenc_conversion_map[key]] = checkpoint[key]
|
841 |
+
if key.startswith("cond_stage_model.model.transformer."):
|
842 |
+
new_key = key[len("cond_stage_model.model.transformer.") :]
|
843 |
+
if new_key.endswith(".in_proj_weight"):
|
844 |
+
new_key = new_key[: -len(".in_proj_weight")]
|
845 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
846 |
+
text_model_dict[new_key + ".q_proj.weight"] = checkpoint[key][:d_model, :]
|
847 |
+
text_model_dict[new_key + ".k_proj.weight"] = checkpoint[key][d_model : d_model * 2, :]
|
848 |
+
text_model_dict[new_key + ".v_proj.weight"] = checkpoint[key][d_model * 2 :, :]
|
849 |
+
elif new_key.endswith(".in_proj_bias"):
|
850 |
+
new_key = new_key[: -len(".in_proj_bias")]
|
851 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
852 |
+
text_model_dict[new_key + ".q_proj.bias"] = checkpoint[key][:d_model]
|
853 |
+
text_model_dict[new_key + ".k_proj.bias"] = checkpoint[key][d_model : d_model * 2]
|
854 |
+
text_model_dict[new_key + ".v_proj.bias"] = checkpoint[key][d_model * 2 :]
|
855 |
+
else:
|
856 |
+
new_key = textenc_pattern.sub(lambda m: protected[re.escape(m.group(0))], new_key)
|
857 |
+
|
858 |
+
text_model_dict[new_key] = checkpoint[key]
|
859 |
+
|
860 |
+
text_model.load_state_dict(text_model_dict)
|
861 |
+
|
862 |
+
return text_model
|
863 |
+
|
864 |
+
|
865 |
+
def stable_unclip_image_encoder(original_config):
|
866 |
+
"""
|
867 |
+
Returns the image processor and clip image encoder for the img2img unclip pipeline.
|
868 |
+
|
869 |
+
We currently know of two types of stable unclip models which separately use the clip and the openclip image
|
870 |
+
encoders.
|
871 |
+
"""
|
872 |
+
|
873 |
+
image_embedder_config = original_config.model.params.embedder_config
|
874 |
+
|
875 |
+
sd_clip_image_embedder_class = image_embedder_config.target
|
876 |
+
sd_clip_image_embedder_class = sd_clip_image_embedder_class.split(".")[-1]
|
877 |
+
|
878 |
+
if sd_clip_image_embedder_class == "ClipImageEmbedder":
|
879 |
+
clip_model_name = image_embedder_config.params.model
|
880 |
+
|
881 |
+
if clip_model_name == "ViT-L/14":
|
882 |
+
feature_extractor = CLIPImageProcessor()
|
883 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("openai/clip-vit-large-patch14")
|
884 |
+
else:
|
885 |
+
raise NotImplementedError(f"Unknown CLIP checkpoint name in stable diffusion checkpoint {clip_model_name}")
|
886 |
+
|
887 |
+
elif sd_clip_image_embedder_class == "FrozenOpenCLIPImageEmbedder":
|
888 |
+
feature_extractor = CLIPImageProcessor()
|
889 |
+
image_encoder = CLIPVisionModelWithProjection.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
|
890 |
+
else:
|
891 |
+
raise NotImplementedError(
|
892 |
+
f"Unknown CLIP image embedder class in stable diffusion checkpoint {sd_clip_image_embedder_class}"
|
893 |
+
)
|
894 |
+
|
895 |
+
return feature_extractor, image_encoder
|
896 |
+
|
897 |
+
|
898 |
+
def stable_unclip_image_noising_components(
|
899 |
+
original_config, clip_stats_path: Optional[str] = None, device: Optional[str] = None
|
900 |
+
):
|
901 |
+
"""
|
902 |
+
Returns the noising components for the img2img and txt2img unclip pipelines.
|
903 |
+
|
904 |
+
Converts the stability noise augmentor into
|
905 |
+
1. a `StableUnCLIPImageNormalizer` for holding the CLIP stats
|
906 |
+
2. a `DDPMScheduler` for holding the noise schedule
|
907 |
+
|
908 |
+
If the noise augmentor config specifies a clip stats path, the `clip_stats_path` must be provided.
|
909 |
+
"""
|
910 |
+
noise_aug_config = original_config.model.params.noise_aug_config
|
911 |
+
noise_aug_class = noise_aug_config.target
|
912 |
+
noise_aug_class = noise_aug_class.split(".")[-1]
|
913 |
+
|
914 |
+
if noise_aug_class == "CLIPEmbeddingNoiseAugmentation":
|
915 |
+
noise_aug_config = noise_aug_config.params
|
916 |
+
embedding_dim = noise_aug_config.timestep_dim
|
917 |
+
max_noise_level = noise_aug_config.noise_schedule_config.timesteps
|
918 |
+
beta_schedule = noise_aug_config.noise_schedule_config.beta_schedule
|
919 |
+
|
920 |
+
image_normalizer = StableUnCLIPImageNormalizer(embedding_dim=embedding_dim)
|
921 |
+
image_noising_scheduler = DDPMScheduler(num_train_timesteps=max_noise_level, beta_schedule=beta_schedule)
|
922 |
+
|
923 |
+
if "clip_stats_path" in noise_aug_config:
|
924 |
+
if clip_stats_path is None:
|
925 |
+
raise ValueError("This stable unclip config requires a `clip_stats_path`")
|
926 |
+
|
927 |
+
clip_mean, clip_std = torch.load(clip_stats_path, map_location=device)
|
928 |
+
clip_mean = clip_mean[None, :]
|
929 |
+
clip_std = clip_std[None, :]
|
930 |
+
|
931 |
+
clip_stats_state_dict = {
|
932 |
+
"mean": clip_mean,
|
933 |
+
"std": clip_std,
|
934 |
+
}
|
935 |
+
|
936 |
+
image_normalizer.load_state_dict(clip_stats_state_dict)
|
937 |
+
else:
|
938 |
+
raise NotImplementedError(f"Unknown noise augmentor class: {noise_aug_class}")
|
939 |
+
|
940 |
+
return image_normalizer, image_noising_scheduler
|
941 |
+
|
942 |
+
|
943 |
+
def convert_controlnet_checkpoint(
    checkpoint, original_config, checkpoint_path, image_size, upcast_attention, extract_ema
):
    ctrlnet_config = create_unet_diffusers_config(original_config, image_size=image_size, controlnet=True)
    ctrlnet_config["upcast_attention"] = upcast_attention

    ctrlnet_config.pop("sample_size")

    controlnet_model = ControlNetModel(**ctrlnet_config)

    converted_ctrl_checkpoint = convert_ldm_unet_checkpoint(
        checkpoint, ctrlnet_config, path=checkpoint_path, extract_ema=extract_ema, controlnet=True
    )

    controlnet_model.load_state_dict(converted_ctrl_checkpoint)

    return controlnet_model
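# Illustrative sketch (not part of the original file): end-to-end ControlNet conversion;
# the `original_config`, `state_dict`, and `ckpt_path` names are hypothetical.
#
#   controlnet = convert_controlnet_checkpoint(
#       state_dict, original_config, ckpt_path, image_size=512, upcast_attention=False, extract_ema=False
#   )
#   controlnet.save_pretrained("converted_controlnet")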
utils/convert_lora_safetensor_to_diffusers.py
ADDED
@@ -0,0 +1,154 @@
1 |
+
# coding=utf-8
|
2 |
+
# Copyright 2023, Haofan Wang, Qixun Wang, All rights reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
""" Conversion script for the LoRA's safetensors checkpoints. """
|
17 |
+
|
18 |
+
import argparse
|
19 |
+
|
20 |
+
import torch
|
21 |
+
from safetensors.torch import load_file
|
22 |
+
|
23 |
+
from diffusers import StableDiffusionPipeline
|
24 |
+
import pdb
|
25 |
+
|
26 |
+
|
27 |
+
|
28 |
+
def convert_motion_lora_ckpt_to_diffusers(pipeline, state_dict, alpha=1.0):
    # directly update weight in diffusers model
    for key in state_dict:
        # only process lora down key
        if "up." in key: continue

        up_key = key.replace(".down.", ".up.")
        model_key = key.replace("processor.", "").replace("_lora", "").replace("down.", "").replace("up.", "")
        model_key = model_key.replace("to_out.", "to_out.0.")
        layer_infos = model_key.split(".")[:-1]

        curr_layer = pipeline.unet
        while len(layer_infos) > 0:
            temp_name = layer_infos.pop(0)
            curr_layer = curr_layer.__getattr__(temp_name)

        weight_down = state_dict[key]
        weight_up = state_dict[up_key]
        curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)

    return pipeline
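# Note (not in the original file): the in-place update above is the usual LoRA merge
# W <- W + alpha * (B @ A), where B is the `up` matrix and A the `down` matrix, so the rank-r
# decomposition is folded directly into the frozen base weight with strength `alpha`.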
|
49 |
+
|
50 |
+
|
51 |
+
|
52 |
+
def convert_lora(pipeline, state_dict, LORA_PREFIX_UNET="lora_unet", LORA_PREFIX_TEXT_ENCODER="lora_te", alpha=0.6):
|
53 |
+
# load base model
|
54 |
+
# pipeline = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
|
55 |
+
|
56 |
+
# load LoRA weight from .safetensors
|
57 |
+
# state_dict = load_file(checkpoint_path)
|
58 |
+
|
59 |
+
visited = []
|
60 |
+
|
61 |
+
# directly update weight in diffusers model
|
62 |
+
for key in state_dict:
|
63 |
+
# it is suggested to print out the key, it usually will be something like below
|
64 |
+
# "lora_te_text_model_encoder_layers_0_self_attn_k_proj.lora_down.weight"
|
65 |
+
|
66 |
+
# alpha is taken from the function argument, so skip the stored .alpha entries and keys already merged
|
67 |
+
if ".alpha" in key or key in visited:
|
68 |
+
continue
|
69 |
+
|
70 |
+
if "text" in key:
|
71 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_TEXT_ENCODER + "_")[-1].split("_")
|
72 |
+
curr_layer = pipeline.text_encoder
|
73 |
+
else:
|
74 |
+
layer_infos = key.split(".")[0].split(LORA_PREFIX_UNET + "_")[-1].split("_")
|
75 |
+
curr_layer = pipeline.unet
|
76 |
+
|
77 |
+
# find the target layer
|
78 |
+
temp_name = layer_infos.pop(0)
|
79 |
+
while len(layer_infos) > -1:
|
80 |
+
try:
|
81 |
+
curr_layer = curr_layer.__getattr__(temp_name)
|
82 |
+
if len(layer_infos) > 0:
|
83 |
+
temp_name = layer_infos.pop(0)
|
84 |
+
elif len(layer_infos) == 0:
|
85 |
+
break
|
86 |
+
except Exception:
|
87 |
+
if len(temp_name) > 0:
|
88 |
+
temp_name += "_" + layer_infos.pop(0)
|
89 |
+
else:
|
90 |
+
temp_name = layer_infos.pop(0)
|
91 |
+
|
92 |
+
pair_keys = []
|
93 |
+
if "lora_down" in key:
|
94 |
+
pair_keys.append(key.replace("lora_down", "lora_up"))
|
95 |
+
pair_keys.append(key)
|
96 |
+
else:
|
97 |
+
pair_keys.append(key)
|
98 |
+
pair_keys.append(key.replace("lora_up", "lora_down"))
|
99 |
+
|
100 |
+
# update weight
|
101 |
+
if len(state_dict[pair_keys[0]].shape) == 4:
|
102 |
+
weight_up = state_dict[pair_keys[0]].squeeze(3).squeeze(2).to(torch.float32)
|
103 |
+
weight_down = state_dict[pair_keys[1]].squeeze(3).squeeze(2).to(torch.float32)
|
104 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3).to(curr_layer.weight.data.device)
|
105 |
+
else:
|
106 |
+
weight_up = state_dict[pair_keys[0]].to(torch.float32)
|
107 |
+
weight_down = state_dict[pair_keys[1]].to(torch.float32)
|
108 |
+
curr_layer.weight.data += alpha * torch.mm(weight_up, weight_down).to(curr_layer.weight.data.device)
|
109 |
+
|
110 |
+
# update visited list
|
111 |
+
for item in pair_keys:
|
112 |
+
visited.append(item)
|
113 |
+
|
114 |
+
return pipeline
|
115 |
+
|
116 |
+
|
117 |
+
if __name__ == "__main__":
|
118 |
+
parser = argparse.ArgumentParser()
|
119 |
+
|
120 |
+
parser.add_argument(
|
121 |
+
"--base_model_path", default=None, type=str, required=True, help="Path to the base model in diffusers format."
|
122 |
+
)
|
123 |
+
parser.add_argument(
|
124 |
+
"--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
|
125 |
+
)
|
126 |
+
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")
|
127 |
+
parser.add_argument(
|
128 |
+
"--lora_prefix_unet", default="lora_unet", type=str, help="The prefix of UNet weight in safetensors"
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--lora_prefix_text_encoder",
|
132 |
+
default="lora_te",
|
133 |
+
type=str,
|
134 |
+
help="The prefix of text encoder weight in safetensors",
|
135 |
+
)
|
136 |
+
parser.add_argument("--alpha", default=0.75, type=float, help="The merging ratio in W = W0 + alpha * deltaW")
|
137 |
+
parser.add_argument(
|
138 |
+
"--to_safetensors", action="store_true", help="Whether to store pipeline in safetensors format or not."
|
139 |
+
)
|
140 |
+
parser.add_argument("--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)")
|
141 |
+
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
base_model_path = args.base_model_path
|
145 |
+
checkpoint_path = args.checkpoint_path
|
146 |
+
dump_path = args.dump_path
|
147 |
+
lora_prefix_unet = args.lora_prefix_unet
|
148 |
+
lora_prefix_text_encoder = args.lora_prefix_text_encoder
|
149 |
+
alpha = args.alpha
|
150 |
+
|
151 |
+
    # `convert` is not defined in this script; build the pipeline and merge the LoRA weights
    # with the `convert_lora` helper defined above (mirrors the upstream conversion script).
    pipe = StableDiffusionPipeline.from_pretrained(base_model_path, torch_dtype=torch.float32)
    state_dict = load_file(checkpoint_path)
    pipe = convert_lora(pipe, state_dict, lora_prefix_unet, lora_prefix_text_encoder, alpha)

    pipe = pipe.to(args.device)
    pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
|
utils/diffuser_utils.py
ADDED
@@ -0,0 +1,275 @@
1 |
+
"""
|
2 |
+
Utility functions based on the Diffusers framework.
|
3 |
+
"""
|
4 |
+
|
5 |
+
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
import cv2
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
import torch.nn.functional as F
|
12 |
+
from tqdm import tqdm
|
13 |
+
from PIL import Image
|
14 |
+
from torchvision.utils import save_image
|
15 |
+
from torchvision.io import read_image
|
16 |
+
|
17 |
+
from diffusers import StableDiffusionPipeline
|
18 |
+
|
19 |
+
from pytorch_lightning import seed_everything
|
20 |
+
|
21 |
+
|
22 |
+
class MasaCtrlPipeline(StableDiffusionPipeline):
|
23 |
+
|
24 |
+
    def next_step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta=0.,
        verbose=False
    ):
        """
        Inverse sampling step for DDIM inversion.
        """
        if verbose:
            print("timestep: ", timestep)
        next_step = timestep
        timestep = min(timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps, 999)
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep] if timestep >= 0 else self.scheduler.final_alpha_cumprod
        alpha_prod_t_next = self.scheduler.alphas_cumprod[next_step]
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_next)**0.5 * model_output
        x_next = alpha_prod_t_next**0.5 * pred_x0 + pred_dir
        return x_next, pred_x0

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        x: torch.FloatTensor,
        eta: float=0.0,
        verbose=False,
    ):
        """
        Predict the sample at the previous timestep (one DDIM denoising step).
        """
        prev_timestep = timestep - self.scheduler.config.num_train_timesteps // self.scheduler.num_inference_steps
        alpha_prod_t = self.scheduler.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.scheduler.alphas_cumprod[prev_timestep] if prev_timestep > 0 else self.scheduler.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        pred_x0 = (x - beta_prod_t**0.5 * model_output) / alpha_prod_t**0.5
        pred_dir = (1 - alpha_prod_t_prev)**0.5 * model_output
        x_prev = alpha_prod_t_prev**0.5 * pred_x0 + pred_dir
        return x_prev, pred_x0
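    # For reference (not in the original file): both methods implement the deterministic (eta=0)
    # DDIM update. With eps = model_output and a_t = alphas_cumprod[t]:
    #   x0_pred = (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    #   x_{t'}  = sqrt(a_{t'}) * x0_pred + sqrt(1 - a_{t'}) * eps
    # `step` moves to the previous timestep t' < t, while `next_step` runs the same formula
    # forward in time (t' > t), which is what makes DDIM inversion possible.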
|
66 |
+
|
67 |
+
@torch.no_grad()
|
68 |
+
def image2latent(self, image):
|
69 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
70 |
+
if isinstance(image, Image.Image):  # `type(image) is Image` compares against the PIL module and is always False
|
71 |
+
image = np.array(image)
|
72 |
+
image = torch.from_numpy(image).float() / 127.5 - 1
|
73 |
+
image = image.permute(2, 0, 1).unsqueeze(0).to(DEVICE)
|
74 |
+
# input image density range [-1, 1]
|
75 |
+
latents = self.vae.encode(image)['latent_dist'].mean
|
76 |
+
latents = latents * 0.18215
|
77 |
+
return latents
|
78 |
+
|
79 |
+
@torch.no_grad()
|
80 |
+
def latent2image(self, latents, return_type='np'):
|
81 |
+
latents = 1 / 0.18215 * latents.detach()
|
82 |
+
image = self.vae.decode(latents)['sample']
|
83 |
+
if return_type == 'np':
|
84 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
85 |
+
image = image.cpu().permute(0, 2, 3, 1).numpy()[0]
|
86 |
+
image = (image * 255).astype(np.uint8)
|
87 |
+
elif return_type == "pt":
|
88 |
+
image = (image / 2 + 0.5).clamp(0, 1)
|
89 |
+
|
90 |
+
return image
|
91 |
+
|
92 |
+
def latent2image_grad(self, latents):
|
93 |
+
latents = 1 / 0.18215 * latents
|
94 |
+
image = self.vae.decode(latents)['sample']
|
95 |
+
|
96 |
+
return image # range [-1, 1]
|
97 |
+
|
98 |
+
@torch.no_grad()
|
99 |
+
def __call__(
|
100 |
+
self,
|
101 |
+
prompt,
|
102 |
+
batch_size=1,
|
103 |
+
height=512,
|
104 |
+
width=512,
|
105 |
+
num_inference_steps=50,
|
106 |
+
guidance_scale=7.5,
|
107 |
+
eta=0.0,
|
108 |
+
latents=None,
|
109 |
+
unconditioning=None,
|
110 |
+
neg_prompt=None,
|
111 |
+
ref_intermediate_latents=None,
|
112 |
+
return_intermediates=False,
|
113 |
+
**kwds):
|
114 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
115 |
+
if isinstance(prompt, list):
|
116 |
+
batch_size = len(prompt)
|
117 |
+
elif isinstance(prompt, str):
|
118 |
+
if batch_size > 1:
|
119 |
+
prompt = [prompt] * batch_size
|
120 |
+
|
121 |
+
# text embeddings
|
122 |
+
text_input = self.tokenizer(
|
123 |
+
prompt,
|
124 |
+
padding="max_length",
|
125 |
+
max_length=77,
|
126 |
+
return_tensors="pt"
|
127 |
+
)
|
128 |
+
|
129 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
|
130 |
+
print("input text embeddings :", text_embeddings.shape)
|
131 |
+
if kwds.get("dir"):
|
132 |
+
dir = text_embeddings[-2] - text_embeddings[-1]
|
133 |
+
u, s, v = torch.pca_lowrank(dir.transpose(-1, -2), q=1, center=True)
|
134 |
+
text_embeddings[-1] = text_embeddings[-1] + kwds.get("dir") * v
|
135 |
+
print(u.shape)
|
136 |
+
print(v.shape)
|
137 |
+
|
138 |
+
# define initial latents
|
139 |
+
latents_shape = (batch_size, self.unet.config.in_channels, height//8, width//8)
|
140 |
+
if latents is None:
|
141 |
+
latents = torch.randn(latents_shape, device=DEVICE)
|
142 |
+
else:
|
143 |
+
assert latents.shape == latents_shape, f"The shape of input latent tensor {latents.shape} should equal to predefined one."
|
144 |
+
|
145 |
+
# unconditional embedding for classifier free guidance
|
146 |
+
if guidance_scale > 1.:
|
147 |
+
max_length = text_input.input_ids.shape[-1]
|
148 |
+
if neg_prompt:
|
149 |
+
uc_text = neg_prompt
|
150 |
+
else:
|
151 |
+
uc_text = ""
|
152 |
+
# uc_text = "ugly, tiling, poorly drawn hands, poorly drawn feet, body out of frame, cut off, low contrast, underexposed, distorted face"
|
153 |
+
unconditional_input = self.tokenizer(
|
154 |
+
[uc_text] * batch_size,
|
155 |
+
padding="max_length",
|
156 |
+
max_length=77,
|
157 |
+
return_tensors="pt"
|
158 |
+
)
|
159 |
+
# unconditional_input.input_ids = unconditional_input.input_ids[:, 1:]
|
160 |
+
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
|
161 |
+
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
|
162 |
+
|
163 |
+
print("latents shape: ", latents.shape)
|
164 |
+
# iterative sampling
|
165 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
166 |
+
# print("Valid timesteps: ", reversed(self.scheduler.timesteps))
|
167 |
+
latents_list = [latents]
|
168 |
+
pred_x0_list = [latents]
|
169 |
+
for i, t in enumerate(tqdm(self.scheduler.timesteps, desc="DDIM Sampler")):
|
170 |
+
if ref_intermediate_latents is not None:
|
171 |
+
# note that the batch_size >= 2
|
172 |
+
latents_ref = ref_intermediate_latents[-1 - i]
|
173 |
+
_, latents_cur = latents.chunk(2)
|
174 |
+
latents = torch.cat([latents_ref, latents_cur])
|
175 |
+
|
176 |
+
if guidance_scale > 1.:
|
177 |
+
model_inputs = torch.cat([latents] * 2)
|
178 |
+
else:
|
179 |
+
model_inputs = latents
|
180 |
+
if unconditioning is not None and isinstance(unconditioning, list):
|
181 |
+
_, text_embeddings = text_embeddings.chunk(2)
|
182 |
+
text_embeddings = torch.cat([unconditioning[i].expand(*text_embeddings.shape), text_embeddings])
|
183 |
+
# predict the noise
|
184 |
+
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
|
185 |
+
if guidance_scale > 1.:
|
186 |
+
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
|
187 |
+
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
|
188 |
+
# compute the previous noise sample x_t -> x_t-1
|
189 |
+
latents, pred_x0 = self.step(noise_pred, t, latents)
|
190 |
+
latents_list.append(latents)
|
191 |
+
pred_x0_list.append(pred_x0)
|
192 |
+
|
193 |
+
image = self.latent2image(latents, return_type="pt")
|
194 |
+
if return_intermediates:
|
195 |
+
pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
|
196 |
+
latents_list = [self.latent2image(img, return_type="pt") for img in latents_list]
|
197 |
+
return image, pred_x0_list, latents_list
|
198 |
+
return image
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def invert(
|
202 |
+
self,
|
203 |
+
image: torch.Tensor,
|
204 |
+
prompt,
|
205 |
+
num_inference_steps=50,
|
206 |
+
guidance_scale=7.5,
|
207 |
+
eta=0.0,
|
208 |
+
return_intermediates=False,
|
209 |
+
**kwds):
|
210 |
+
"""
|
211 |
+
Invert a real image into a noise map with deterministic DDIM inversion.
|
212 |
+
"""
|
213 |
+
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
214 |
+
batch_size = image.shape[0]
|
215 |
+
if isinstance(prompt, list):
|
216 |
+
if batch_size == 1:
|
217 |
+
image = image.expand(len(prompt), -1, -1, -1)
|
218 |
+
elif isinstance(prompt, str):
|
219 |
+
if batch_size > 1:
|
220 |
+
prompt = [prompt] * batch_size
|
221 |
+
|
222 |
+
# text embeddings
|
223 |
+
text_input = self.tokenizer(
|
224 |
+
prompt,
|
225 |
+
padding="max_length",
|
226 |
+
max_length=77,
|
227 |
+
return_tensors="pt"
|
228 |
+
)
|
229 |
+
text_embeddings = self.text_encoder(text_input.input_ids.to(DEVICE))[0]
|
230 |
+
print("input text embeddings :", text_embeddings.shape)
|
231 |
+
# define initial latents
|
232 |
+
latents = self.image2latent(image)
|
233 |
+
start_latents = latents
|
234 |
+
# print(latents)
|
235 |
+
# exit()
|
236 |
+
# unconditional embedding for classifier free guidance
|
237 |
+
if guidance_scale > 1.:
|
238 |
+
max_length = text_input.input_ids.shape[-1]
|
239 |
+
unconditional_input = self.tokenizer(
|
240 |
+
[""] * batch_size,
|
241 |
+
padding="max_length",
|
242 |
+
max_length=77,
|
243 |
+
return_tensors="pt"
|
244 |
+
)
|
245 |
+
unconditional_embeddings = self.text_encoder(unconditional_input.input_ids.to(DEVICE))[0]
|
246 |
+
text_embeddings = torch.cat([unconditional_embeddings, text_embeddings], dim=0)
|
247 |
+
|
248 |
+
print("latents shape: ", latents.shape)
|
249 |
+
# iterative sampling
|
250 |
+
self.scheduler.set_timesteps(num_inference_steps)
|
251 |
+
print("Valid timesteps: ", reversed(self.scheduler.timesteps))
|
252 |
+
# print("attributes: ", self.scheduler.__dict__)
|
253 |
+
latents_list = [latents]
|
254 |
+
pred_x0_list = [latents]
|
255 |
+
for i, t in enumerate(tqdm(reversed(self.scheduler.timesteps), desc="DDIM Inversion")):
|
256 |
+
if guidance_scale > 1.:
|
257 |
+
model_inputs = torch.cat([latents] * 2)
|
258 |
+
else:
|
259 |
+
model_inputs = latents
|
260 |
+
|
261 |
+
# predict the noise
|
262 |
+
noise_pred = self.unet(model_inputs, t, encoder_hidden_states=text_embeddings).sample
|
263 |
+
if guidance_scale > 1.:
|
264 |
+
noise_pred_uncon, noise_pred_con = noise_pred.chunk(2, dim=0)
|
265 |
+
noise_pred = noise_pred_uncon + guidance_scale * (noise_pred_con - noise_pred_uncon)
|
266 |
+
# compute the previous noise sample x_t-1 -> x_t
|
267 |
+
latents, pred_x0 = self.next_step(noise_pred, t, latents)
|
268 |
+
latents_list.append(latents)
|
269 |
+
pred_x0_list.append(pred_x0)
|
270 |
+
|
271 |
+
if return_intermediates:
|
272 |
+
# return the intermediate latents during inversion
|
273 |
+
# pred_x0_list = [self.latent2image(img, return_type="pt") for img in pred_x0_list]
|
274 |
+
return latents, latents_list
|
275 |
+
return latents, start_latents
|
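# Illustrative sketch (not part of the original file): a typical inversion/resynthesis round trip
# with this pipeline; `model_path` and `source_image` are hypothetical.
#
#   pipe = MasaCtrlPipeline.from_pretrained(model_path).to("cuda")
#   start_code, latents_list = pipe.invert(source_image, "a photo of a cat", return_intermediates=True)
#   image = pipe("a photo of a cat", latents=start_code, guidance_scale=7.5)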
utils/free_lunch_utils.py
ADDED
@@ -0,0 +1,334 @@
1 |
+
import torch
|
2 |
+
import torch.fft as fft
|
3 |
+
from diffusers.models.unet_2d_condition import logger
|
4 |
+
from diffusers.utils import is_torch_version
|
5 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
6 |
+
import torch.nn.functional as F
|
7 |
+
|
8 |
+
def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """

    for _cls in x.__class__.__mro__:
        if _cls.__name__ == cls_name:
            return True

    return False


def Fourier_filter(x, threshold, scale):
    dtype = x.dtype
    x = x.type(torch.float32)
    # FFT
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    mask = torch.ones((B, C, H, W)).cuda()

    crow, ccol = H // 2, W // 2
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered
|
43 |
+
|
44 |
+
|
45 |
+
def register_upblock2d(model):
|
46 |
+
def up_forward(self):
|
47 |
+
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=None):
|
48 |
+
for resnet in self.resnets:
|
49 |
+
# pop res hidden states
|
50 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
51 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
52 |
+
#print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
|
53 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
54 |
+
|
55 |
+
if self.training and self.gradient_checkpointing:
|
56 |
+
|
57 |
+
def create_custom_forward(module):
|
58 |
+
def custom_forward(*inputs):
|
59 |
+
return module(*inputs)
|
60 |
+
|
61 |
+
return custom_forward
|
62 |
+
|
63 |
+
if is_torch_version(">=", "1.11.0"):
|
64 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
65 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
66 |
+
)
|
67 |
+
else:
|
68 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
69 |
+
create_custom_forward(resnet), hidden_states, temb
|
70 |
+
)
|
71 |
+
else:
|
72 |
+
hidden_states = resnet(hidden_states, temb)
|
73 |
+
|
74 |
+
if self.upsamplers is not None:
|
75 |
+
for upsampler in self.upsamplers:
|
76 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
77 |
+
|
78 |
+
return hidden_states
|
79 |
+
|
80 |
+
return forward
|
81 |
+
|
82 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
83 |
+
if isinstance_str(upsample_block, "UpBlock2D"):
|
84 |
+
upsample_block.forward = up_forward(upsample_block)
|
85 |
+
|
86 |
+
|
87 |
+
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2,source_mask=None):
|
88 |
+
def up_forward(self):
|
89 |
+
def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None, scale=None):
|
90 |
+
for resnet in self.resnets:
|
91 |
+
# pop res hidden states
|
92 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
93 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
94 |
+
#print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")
|
95 |
+
|
96 |
+
if self.source_mask is not None:
|
97 |
+
spatial_mask_source = F.interpolate(self.source_mask, (hidden_states.shape[2], hidden_states.shape[3]))
|
98 |
+
spatial_mask_source_b1 = spatial_mask_source * self.b1 + (1 - spatial_mask_source)
|
99 |
+
spatial_mask_source_b2 = spatial_mask_source * self.b2 + (1 - spatial_mask_source)
|
100 |
+
# --------------- FreeU code -----------------------
|
101 |
+
# Only operate on the first two stages
|
102 |
+
if hidden_states.shape[1] == 1280:
|
103 |
+
if self.source_mask is not None:
|
104 |
+
#where in mask = 0, set hidden states unchanged
|
105 |
+
hidden_states[:,:640] = hidden_states[:,:640] * spatial_mask_source_b1
|
106 |
+
|
107 |
+
else:
|
108 |
+
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
109 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
110 |
+
if hidden_states.shape[1] == 640:
|
111 |
+
|
112 |
+
if self.source_mask is not None:
|
113 |
+
hidden_states[:,:320] = hidden_states[:,:320] * spatial_mask_source_b2
|
114 |
+
else:
|
115 |
+
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
116 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
117 |
+
# ---------------------------------------------------------
|
118 |
+
|
119 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
120 |
+
|
121 |
+
if self.training and self.gradient_checkpointing:
|
122 |
+
|
123 |
+
def create_custom_forward(module):
|
124 |
+
def custom_forward(*inputs):
|
125 |
+
return module(*inputs)
|
126 |
+
|
127 |
+
return custom_forward
|
128 |
+
|
129 |
+
if is_torch_version(">=", "1.11.0"):
|
130 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
131 |
+
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
|
132 |
+
)
|
133 |
+
else:
|
134 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
135 |
+
create_custom_forward(resnet), hidden_states, temb
|
136 |
+
)
|
137 |
+
else:
|
138 |
+
hidden_states = resnet(hidden_states, temb)
|
139 |
+
|
140 |
+
if self.upsamplers is not None:
|
141 |
+
for upsampler in self.upsamplers:
|
142 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
143 |
+
|
144 |
+
return hidden_states
|
145 |
+
|
146 |
+
return forward
|
147 |
+
|
148 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
149 |
+
if isinstance_str(upsample_block, "UpBlock2D"):
|
150 |
+
upsample_block.forward = up_forward(upsample_block)
|
151 |
+
setattr(upsample_block, 'b1', b1)
|
152 |
+
setattr(upsample_block, 'b2', b2)
|
153 |
+
setattr(upsample_block, 's1', s1)
|
154 |
+
setattr(upsample_block, 's2', s2)
|
155 |
+
setattr(upsample_block, 'source_mask', source_mask)
|
156 |
+
|
157 |
+
def register_crossattn_upblock2d(model):
|
158 |
+
def up_forward(self):
|
159 |
+
def forward(
|
160 |
+
hidden_states: torch.FloatTensor,
|
161 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
162 |
+
temb: Optional[torch.FloatTensor] = None,
|
163 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
164 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
165 |
+
upsample_size: Optional[int] = None,
|
166 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
167 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
168 |
+
):
|
169 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
170 |
+
# pop res hidden states
|
171 |
+
#print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
172 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
173 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
174 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
175 |
+
|
176 |
+
if self.training and self.gradient_checkpointing:
|
177 |
+
|
178 |
+
def create_custom_forward(module, return_dict=None):
|
179 |
+
def custom_forward(*inputs):
|
180 |
+
if return_dict is not None:
|
181 |
+
return module(*inputs, return_dict=return_dict)
|
182 |
+
else:
|
183 |
+
return module(*inputs)
|
184 |
+
|
185 |
+
return custom_forward
|
186 |
+
|
187 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
188 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
189 |
+
create_custom_forward(resnet),
|
190 |
+
hidden_states,
|
191 |
+
temb,
|
192 |
+
**ckpt_kwargs,
|
193 |
+
)
|
194 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
195 |
+
create_custom_forward(attn, return_dict=False),
|
196 |
+
hidden_states,
|
197 |
+
encoder_hidden_states,
|
198 |
+
None, # timestep
|
199 |
+
None, # class_labels
|
200 |
+
cross_attention_kwargs,
|
201 |
+
attention_mask,
|
202 |
+
encoder_attention_mask,
|
203 |
+
**ckpt_kwargs,
|
204 |
+
)[0]
|
205 |
+
else:
|
206 |
+
hidden_states = resnet(hidden_states, temb)
|
207 |
+
hidden_states = attn(
|
208 |
+
hidden_states,
|
209 |
+
encoder_hidden_states=encoder_hidden_states,
|
210 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
211 |
+
attention_mask=attention_mask,
|
212 |
+
encoder_attention_mask=encoder_attention_mask,
|
213 |
+
return_dict=False,
|
214 |
+
)[0]
|
215 |
+
|
216 |
+
if self.upsamplers is not None:
|
217 |
+
for upsampler in self.upsamplers:
|
218 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
219 |
+
|
220 |
+
return hidden_states
|
221 |
+
|
222 |
+
return forward
|
223 |
+
|
224 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
225 |
+
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
226 |
+
upsample_block.forward = up_forward(upsample_block)
|
227 |
+
|
228 |
+
|
229 |
+
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2,source_mask=None):
|
230 |
+
def up_forward(self):
|
231 |
+
def forward(
|
232 |
+
hidden_states: torch.FloatTensor,
|
233 |
+
res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
|
234 |
+
temb: Optional[torch.FloatTensor] = None,
|
235 |
+
encoder_hidden_states: Optional[torch.FloatTensor] = None,
|
236 |
+
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
237 |
+
upsample_size: Optional[int] = None,
|
238 |
+
attention_mask: Optional[torch.FloatTensor] = None,
|
239 |
+
encoder_attention_mask: Optional[torch.FloatTensor] = None,
|
240 |
+
):
|
241 |
+
|
242 |
+
if self.source_mask is not None:
|
243 |
+
|
244 |
+
spatial_mask_source = F.interpolate(self.source_mask, (hidden_states.shape[2], hidden_states.shape[3]))
|
245 |
+
spatial_mask_source_b1 = spatial_mask_source * self.b1 + (1 - spatial_mask_source)
|
246 |
+
spatial_mask_source_b2 = spatial_mask_source * self.b2 + (1 - spatial_mask_source)
|
247 |
+
# print(f"source mask is not none, {spatial_mask_source_b1.shape} with min {spatial_mask_source_b1.min()}", )
|
248 |
+
|
249 |
+
for resnet, attn in zip(self.resnets, self.attentions):
|
250 |
+
# pop res hidden states
|
251 |
+
#print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
|
252 |
+
res_hidden_states = res_hidden_states_tuple[-1]
|
253 |
+
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
|
254 |
+
|
255 |
+
# --------------- FreeU code -----------------------
|
256 |
+
# Only operate on the first two stages
|
257 |
+
if hidden_states.shape[1] == 1280:
|
258 |
+
if self.source_mask is not None:
|
259 |
+
#where in mask = 0, set hidden states unchanged
|
260 |
+
hidden_states[:,:640] = hidden_states[:,:640] * spatial_mask_source_b1
|
261 |
+
|
262 |
+
else:
|
263 |
+
hidden_states[:,:640] = hidden_states[:,:640] * self.b1
|
264 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
|
265 |
+
if hidden_states.shape[1] == 640:
|
266 |
+
if self.source_mask is not None:
|
267 |
+
hidden_states[:,:320] = hidden_states[:,:320] * spatial_mask_source_b2
|
268 |
+
else:
|
269 |
+
hidden_states[:,:320] = hidden_states[:,:320] * self.b2
|
270 |
+
res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
|
271 |
+
# ---------------------------------------------------------
|
272 |
+
|
273 |
+
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
|
274 |
+
|
275 |
+
if self.training and self.gradient_checkpointing:
|
276 |
+
|
277 |
+
def create_custom_forward(module, return_dict=None):
|
278 |
+
def custom_forward(*inputs):
|
279 |
+
if return_dict is not None:
|
280 |
+
return module(*inputs, return_dict=return_dict)
|
281 |
+
else:
|
282 |
+
return module(*inputs)
|
283 |
+
|
284 |
+
return custom_forward
|
285 |
+
|
286 |
+
ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
|
287 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
288 |
+
create_custom_forward(resnet),
|
289 |
+
hidden_states,
|
290 |
+
temb,
|
291 |
+
**ckpt_kwargs,
|
292 |
+
)
|
293 |
+
hidden_states = torch.utils.checkpoint.checkpoint(
|
294 |
+
create_custom_forward(attn, return_dict=False),
|
295 |
+
hidden_states,
|
296 |
+
encoder_hidden_states,
|
297 |
+
None, # timestep
|
298 |
+
None, # class_labels
|
299 |
+
cross_attention_kwargs,
|
300 |
+
attention_mask,
|
301 |
+
encoder_attention_mask,
|
302 |
+
**ckpt_kwargs,
|
303 |
+
)[0]
|
304 |
+
else:
|
305 |
+
hidden_states = resnet(hidden_states, temb)
|
306 |
+
# hidden_states = attn(
|
307 |
+
# hidden_states,
|
308 |
+
# encoder_hidden_states=encoder_hidden_states,
|
309 |
+
# cross_attention_kwargs=cross_attention_kwargs,
|
310 |
+
# encoder_attention_mask=encoder_attention_mask,
|
311 |
+
# return_dict=False,
|
312 |
+
# )[0]
|
313 |
+
hidden_states = attn(
|
314 |
+
hidden_states,
|
315 |
+
encoder_hidden_states=encoder_hidden_states,
|
316 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
317 |
+
)[0]
|
318 |
+
|
319 |
+
if self.upsamplers is not None:
|
320 |
+
for upsampler in self.upsamplers:
|
321 |
+
hidden_states = upsampler(hidden_states, upsample_size)
|
322 |
+
|
323 |
+
return hidden_states
|
324 |
+
|
325 |
+
return forward
|
326 |
+
|
327 |
+
for i, upsample_block in enumerate(model.unet.up_blocks):
|
328 |
+
if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
|
329 |
+
upsample_block.forward = up_forward(upsample_block)
|
330 |
+
setattr(upsample_block, 'b1', b1)
|
331 |
+
setattr(upsample_block, 'b2', b2)
|
332 |
+
setattr(upsample_block, 's1', s1)
|
333 |
+
setattr(upsample_block, 's2', s2)
|
334 |
+
setattr(upsample_block, 'source_mask', source_mask)
|
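A brief usage sketch for the two `register_free_*` helpers above: they only require an object that exposes a diffusers UNet as `.unet`, so any Stable Diffusion pipeline should work. The checkpoint id below is an illustrative assumption, not something this Space pins.

# Usage sketch (assumption: `pipe` is any object with a diffusers UNet at `pipe.unet`).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# Patch plain and cross-attention up-blocks with the FreeU-style forwards defined above.
register_free_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2, source_mask=None)
register_free_crossattn_upblock2d(pipe, b1=1.2, b2=1.4, s1=0.9, s2=0.2, source_mask=None)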
utils/masactrl_utils.py
ADDED
@@ -0,0 +1,212 @@
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from typing import Optional, Union, Tuple, List, Callable, Dict

from torchvision.utils import save_image
from einops import rearrange, repeat


class AttentionBase:
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # after step
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0


class AttentionStore(AttentionBase):
    def __init__(self, res=[32], min_step=0, max_step=1000):
        super().__init__()
        self.res = res
        self.min_step = min_step
        self.max_step = max_step
        self.valid_steps = 0

        self.self_attns = []  # store all the attns
        self.cross_attns = []

        self.self_attns_step = []  # store the attns in each step
        self.cross_attns_step = []

    def after_step(self):
        if self.cur_step > self.min_step and self.cur_step < self.max_step:
            self.valid_steps += 1
            if len(self.self_attns) == 0:
                self.self_attns = self.self_attns_step
                self.cross_attns = self.cross_attns_step
            else:
                for i in range(len(self.self_attns)):
                    self.self_attns[i] += self.self_attns_step[i]
                    self.cross_attns[i] += self.cross_attns_step[i]
        self.self_attns_step.clear()
        self.cross_attns_step.clear()

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        if attn.shape[1] <= 64 ** 2:  # avoid OOM
            if is_cross:
                self.cross_attns_step.append(attn)
            else:
                self.self_attns_step.append(attn)
        return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)


def regiter_attention_editor_diffusers(model, editor: AttentionBase):
    """
    Register an attention editor to a Diffusers pipeline, adapted from [Prompt-to-Prompt]
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention is similar to the original implementation of the LDM CrossAttention class,
            except for some modifications on the attention
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            to_out = self.to_out
            if isinstance(to_out, nn.modules.container.ModuleList):
                to_out = self.to_out[0]
            else:
                to_out = self.to_out

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # the only difference
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'Attention':  # spatial Transformer layer
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.unet.named_children():
        if "down" in net_name:
            cross_att_count += register_editor(net, 0, "down")
        elif "mid" in net_name:
            cross_att_count += register_editor(net, 0, "mid")
        elif "up" in net_name:
            cross_att_count += register_editor(net, 0, "up")
    editor.num_att_layers = cross_att_count


def regiter_attention_editor_ldm(model, editor: AttentionBase):
    """
    Register an attention editor to a Stable Diffusion (LDM) model, adapted from [Prompt-to-Prompt]
    """
    def ca_forward(self, place_in_unet):
        def forward(x, encoder_hidden_states=None, attention_mask=None, context=None, mask=None):
            """
            The attention is similar to the original implementation of the LDM CrossAttention class,
            except for some modifications on the attention
            """
            if encoder_hidden_states is not None:
                context = encoder_hidden_states
            if attention_mask is not None:
                mask = attention_mask

            to_out = self.to_out
            if isinstance(to_out, nn.modules.container.ModuleList):
                to_out = self.to_out[0]
            else:
                to_out = self.to_out

            h = self.heads
            q = self.to_q(x)
            is_cross = context is not None
            context = context if is_cross else x
            k = self.to_k(context)
            v = self.to_v(context)
            q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

            sim = torch.einsum('b i d, b j d -> b i j', q, k) * self.scale

            if mask is not None:
                mask = rearrange(mask, 'b ... -> b (...)')
                max_neg_value = -torch.finfo(sim.dtype).max
                mask = repeat(mask, 'b j -> (b h) () j', h=h)
                mask = mask[:, None, :].repeat(h, 1, 1)
                sim.masked_fill_(~mask, max_neg_value)

            attn = sim.softmax(dim=-1)
            # the only difference
            out = editor(
                q, k, v, sim, attn, is_cross, place_in_unet,
                self.heads, scale=self.scale)

            return to_out(out)

        return forward

    def register_editor(net, count, place_in_unet):
        for name, subnet in net.named_children():
            if net.__class__.__name__ == 'CrossAttention':  # spatial Transformer layer
                net.forward = ca_forward(net, place_in_unet)
                return count + 1
            elif hasattr(net, 'children'):
                count = register_editor(subnet, count, place_in_unet)
        return count

    cross_att_count = 0
    for net_name, net in model.model.diffusion_model.named_children():
        if "input" in net_name:
            cross_att_count += register_editor(net, 0, "input")
        elif "middle" in net_name:
            cross_att_count += register_editor(net, 0, "middle")
        elif "output" in net_name:
            cross_att_count += register_editor(net, 0, "output")
    editor.num_att_layers = cross_att_count
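A short usage sketch for the editor registration above: `AttentionStore` records self- and cross-attention maps while the patched UNet runs. The `pipe` object is an assumption here (any diffusers pipeline exposing its UNet as `pipe.unet`).

# Usage sketch (assumption: `pipe` is a diffusers pipeline with `pipe.unet`).
editor = AttentionStore(res=[32], min_step=0, max_step=1000)
regiter_attention_editor_diffusers(pipe, editor)

# Run the pipeline as usual; the editor accumulates maps layer by layer.
# editor.self_attns / editor.cross_attns then hold per-layer attention maps summed
# over the recorded steps (editor.valid_steps counts how many steps contributed).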
utils/style_attn_control.py
ADDED
@@ -0,0 +1,275 @@
# Copyright 2023 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from re import U

import numpy as np

from einops import rearrange

from .masactrl_utils import AttentionBase

from torchvision.utils import save_image

import sys

import torch
import torch.nn.functional as F
from torch import nn
import torch.fft as fft

from einops import rearrange, repeat
from diffusers.utils import deprecate, logging
from diffusers.utils.import_utils import is_xformers_available
# from masactrl.masactrl import MutualSelfAttentionControl

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


if is_xformers_available():
    import xformers
    import xformers.ops
else:
    xformers = None


class AttentionBase:
    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def after_step(self):
        pass

    def __call__(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = self.forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            # after step
            self.after_step()
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        out = torch.einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=num_heads)
        return out

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0


class MaskPromptedStyleAttentionControl(AttentionBase):
    def __init__(self, start_step=4, start_layer=10, style_attn_step=35, layer_idx=None, step_idx=None, total_steps=50, style_guidance=0.1,
                 only_masked_region=False, guidance=0.0,
                 style_mask=None, source_mask=None, de_bug=False):
        """
        MaskPromptedSAC
        Args:
            start_step: the step to start mutual self-attention control
            start_layer: the layer to start mutual self-attention control
            layer_idx: list of the layers to apply mutual self-attention control
            step_idx: list of the steps to apply mutual self-attention control
            total_steps: the total number of steps
            thres: the threshold for mask thresholding
            ref_token_idx: the token index list for cross-attention map aggregation
            cur_token_idx: the token index list for cross-attention map aggregation
            mask_save_dir: the path to save the mask image
        """

        super().__init__()
        self.total_steps = total_steps
        self.total_layers = 16
        self.start_step = start_step
        self.start_layer = start_layer
        self.layer_idx = layer_idx if layer_idx is not None else list(range(start_layer, self.total_layers))
        self.step_idx = step_idx if step_idx is not None else list(range(start_step, total_steps))
        print("using MaskPromptStyleAttentionControl")
        print("MaskedSAC at denoising steps: ", self.step_idx)
        print("MaskedSAC at U-Net layers: ", self.layer_idx)

        self.de_bug = de_bug
        self.style_guidance = style_guidance
        self.only_masked_region = only_masked_region
        self.style_attn_step = style_attn_step
        self.self_attns = []
        self.cross_attns = []
        self.guidance = guidance
        self.style_mask = style_mask
        self.source_mask = source_mask

    def after_step(self):
        self.self_attns = []
        self.cross_attns = []

    def attn_batch(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)

        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")

        if q_mask is not None:
            sim = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)

        if k_mask is not None:
            sim = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)

        attn = sim.softmax(-1) if attn is None else attn

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def attn_batch_fg_bg(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, q_mask, k_mask, **kwargs):
        B = q.shape[0] // num_heads
        H = W = int(np.sqrt(q.shape[1]))
        q = rearrange(q, "(b h) n d -> h (b n) d", h=num_heads)
        k = rearrange(k, "(b h) n d -> h (b n) d", h=num_heads)
        v = rearrange(v, "(b h) n d -> h (b n) d", h=num_heads)
        sim = torch.einsum("h i d, h j d -> h i j", q, k) * kwargs.get("scale")
        if q_mask is not None:
            sim_fg = sim.masked_fill(q_mask.unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim.masked_fill(q_mask.unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        if k_mask is not None:
            sim_fg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 0, -torch.finfo(sim.dtype).max)
            sim_bg = sim.masked_fill(k_mask.permute(1, 0).unsqueeze(0) == 1, -torch.finfo(sim.dtype).max)
        sim = torch.cat([sim_fg, sim_bg])
        attn = sim.softmax(-1)

        if len(attn) == 2 * len(v):
            v = torch.cat([v] * 2)
        out = torch.einsum("h i j, h j d -> h i d", attn, v)
        out = rearrange(out, "(h1 h) (b n) d -> (h1 b) n (h d)", b=B, h=num_heads)
        return out

    def forward(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs):
        """
        Attention forward function
        """
        if is_cross or self.cur_step not in self.step_idx or self.cur_att_layer // 2 not in self.layer_idx:
            return super().forward(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, **kwargs)

        B = q.shape[0] // num_heads // 2
        H = W = int(np.sqrt(q.shape[1]))

        if self.style_mask is not None and self.source_mask is not None:
            #mask = self.aggregate_cross_attn_map(idx=self.cur_token_idx)  # (4, H, W)
            heigh, width = self.style_mask.shape[-2:]
            mask_style = self.style_mask  # (H, W)
            mask_source = self.source_mask  # (H, W)
            scale = int(np.sqrt(heigh * width / q.shape[1]))
            # res = int(np.sqrt(q.shape[1]))
            spatial_mask_source = F.interpolate(mask_source, (heigh // scale, width // scale)).reshape(-1, 1)
            spatial_mask_style = F.interpolate(mask_style, (heigh // scale, width // scale)).reshape(-1, 1)
        else:
            spatial_mask_source = None
            spatial_mask_style = None

        if spatial_mask_style is None or spatial_mask_source is None:
            out_s, out_c, out_t = self.style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
        else:
            if self.only_masked_region:
                out_s, out_c, out_t = self.mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)
            else:
                out_s, out_c, out_t = self.separate_mask_prompted_style_attn_ctrl(q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs)

        out = torch.cat([out_s, out_c, out_t], dim=0)
        return out

    def style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()

        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.cur_step < self.style_attn_step:
            out_t = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        else:
            out_t = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
            if self.style_guidance >= 0:
                out_t = out_c + (out_t - out_c) * self.style_guidance
        return out_s, out_c, out_t

    def mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        qs, qc, qt = q.chunk(3)

        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c = self.attn_batch(qc, k[num_heads: 2*num_heads], v[num_heads:2*num_heads], sim[num_heads: 2*num_heads], attn[num_heads: 2*num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        out_c_new = self.attn_batch(qc, k[num_heads: 2*num_heads], v[num_heads:2*num_heads], sim[num_heads: 2*num_heads], None, is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)

        if self.de_bug:
            import pdb; pdb.set_trace()

        if self.cur_step < self.style_attn_step:
            out_t = out_c  #self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
        else:
            out_t_fg = self.attn_batch(qt, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg = self.attn_batch(qc, k[:num_heads], v[:num_heads], sim[:num_heads], None, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            if self.style_guidance >= 0:
                out_t = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance

            out_t = out_t * spatial_mask_source + out_c * (1 - spatial_mask_source)

        if self.de_bug:
            import pdb; pdb.set_trace()

        # print(torch.sum(out_t * (1 - spatial_mask_source) - out_c * (1 - spatial_mask_source)))
        return out_s, out_c, out_t

    def separate_mask_prompted_style_attn_ctrl(self, q, k, v, sim, attn, is_cross, place_in_unet, num_heads, spatial_mask_source, spatial_mask_style, **kwargs):
        if self.de_bug:
            import pdb; pdb.set_trace()
        # To prevent query confusion, render fg and bg according to the mask.
        qs, qc, qt = q.chunk(3)
        out_s = self.attn_batch(qs, k[:num_heads], v[:num_heads], sim[:num_heads], attn[:num_heads], is_cross, place_in_unet, num_heads, q_mask=None, k_mask=None, **kwargs)
        if self.cur_step < self.style_attn_step:
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c_fg, out_c_bg = out_c.chunk(2)
            out_t = out_c_fg * spatial_mask_source + out_c_bg * (1 - spatial_mask_source)
        else:
            out_t = self.attn_batch_fg_bg(qt, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_c = self.attn_batch_fg_bg(qc, k[:num_heads], v[:num_heads], sim[:num_heads], attn, is_cross, place_in_unet, num_heads, q_mask=spatial_mask_source, k_mask=spatial_mask_style, **kwargs)
            out_t_fg, out_t_bg = out_t.chunk(2)
            out_c_fg, out_c_bg = out_c.chunk(2)
            if self.style_guidance >= 0:
                out_t_fg = out_c_fg + (out_t_fg - out_c_fg) * self.style_guidance
                out_t_bg = out_c_bg + (out_t_bg - out_c_bg) * self.style_guidance
            out_t = out_t_fg * spatial_mask_source + out_t_bg * (1 - spatial_mask_source)

        return out_s, out_t, out_t