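# Reference preprocessors for Forge: reference_only, reference_adain, and
# reference_adain+attn. Before each denoising call, a noised latent of the
# reference image is pushed through the UNet to record its self-attention
# keys/values and hidden-state statistics; those recordings are then injected
# into the real sampling pass (attn1 concatenation and/or AdaIN).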
import torch

from modules_forge.supported_preprocessor import Preprocessor, PreprocessorParameter
from modules_forge.shared import add_supported_preprocessor
from ldm_patched.modules.samplers import sampling_function
import ldm_patched.ldm.modules.attention as attention
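# Scaled dot-product attention via the backend-optimized kernel. Tolerates
# empty batches so cond/uncond splits can be processed uniformly.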
def sdp(q, k, v, transformer_options):
    if q.shape[0] == 0:
        return q

    return attention.optimized_attention(q, k, v, heads=transformer_options["n_heads"], mask=None)
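# Adaptive instance normalization: re-normalize x per channel to the target
# std/mean (NCHW input; statistics taken over the spatial dims H and W).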
def adain(x, target_std, target_mean):
    if x.shape[0] == 0:
        return x

    std, mean = torch.std_mean(x, dim=(2, 3), keepdim=True, correction=0)
    return (((x - mean) / std) * target_std) + target_mean
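# torch.cat that skips empty tensors, since the cond or uncond batch may be empty.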
def zero_cat(a, b, dim):
    if a.shape[0] == 0:
        return b
    if b.shape[0] == 0:
        return a
    return torch.cat([a, b], dim=dim)
class PreprocessorReference(Preprocessor):
    def __init__(self, name, use_attn=True, use_adain=True, priority=0):
        super().__init__()
        self.name = name
        self.use_attn = use_attn
        self.use_adain = use_adain
        self.sorting_priority = priority
        self.tags = ['Reference']
        self.slider_resolution = PreprocessorParameter(visible=False)
        self.slider_1 = PreprocessorParameter(label='Style Fidelity', value=0.5, minimum=0.0, maximum=1.0, step=0.01, visible=True)
        self.show_control_mode = False
        # The 'corp' spelling follows the attribute defined on the Preprocessor base class.
        self.corp_image_with_a1111_mask_when_in_img2img_inpaint_tab = False
        self.do_not_need_model = True

        # Recording state, populated each step by the hooks installed below.
        self.is_recording_style = False
        self.recorded_attn1 = {}  # (block_name, block_id, block_index) -> (k, v)
        self.recorded_h = {}      # block location -> (std, mean)
    def process_before_every_sampling(self, process, cond, mask, *args, **kwargs):
        unit = kwargs['unit']
        weight = float(unit.weight)
        style_fidelity = float(unit.threshold_a)
        start_percent = float(unit.guidance_start)
        end_percent = float(unit.guidance_end)

        if process.sd_model.is_sdxl:
            # SDXL is more sensitive to this setting; cubing pulls values below 1.0 toward 0.
            style_fidelity = style_fidelity ** 3.0

        vae = process.sd_model.forge_objects.vae

        # Encode the reference image into the UNet's latent space
        # (movedim converts BCHW to the BHWC layout the VAE expects).
        latent_image = vae.encode(cond.movedim(1, -1))
        latent_image = process.sd_model.forge_objects.unet.model.latent_format.process_in(latent_image)

        # Deterministic CPU generator, offset from the main seed, for the
        # noise added to the reference latent.
        gen_seed = process.seeds[0] + 1
        gen_cpu = torch.Generator().manual_seed(gen_seed)

        unet = process.sd_model.forge_objects.unet.clone()

        # Convert the guidance start/end percentages into the sigma range in
        # which reference injection is active.
        sigma_max = unet.model.model_sampling.percent_to_sigma(start_percent)
        sigma_min = unet.model.model_sampling.percent_to_sigma(end_percent)

        self.recorded_attn1 = {}
        self.recorded_h = {}
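        # Conditioning modifier: while the current sigma is inside the active
        # range, run a recording pass in which the UNet sees the noised
        # reference latent, so the hooks below can capture its attn1 K/V and
        # hidden-state statistics. The real inputs are returned unchanged.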
        def conditioning_modifier(model, x, timestep, uncond, cond, cond_scale, model_options, seed):
            sigma = timestep[0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return model, x, timestep, uncond, cond, cond_scale, model_options, seed

            self.is_recording_style = True

            # Noise the reference latent to the current sigma and run one
            # denoising call (cond_scale=1); its output is discarded, it only
            # serves to populate the recordings via the hooks below.
            xt = latent_image.to(x) + torch.randn(x.size(), dtype=x.dtype, generator=gen_cpu).to(x) * sigma
            sampling_function(model, xt, timestep, uncond, cond, 1, model_options, seed)

            self.is_recording_style = False

            return model, x, timestep, uncond, cond, cond_scale, model_options, seed
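        # Block modifier: after each hooked UNet block, either record the
        # hidden state's per-channel std/mean (recording pass) or AdaIN the
        # live hidden state toward the recorded reference statistics.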
        def block_proc(h, flag, transformer_options):
            if not self.use_adain:
                return h

            if flag != 'after':
                return h

            location = transformer_options['block']

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return h

            # Only touch blocks with enough channels; a higher weight lowers
            # the threshold and therefore affects more blocks.
            channel = int(h.shape[1])
            minimal_channel = 1500 - 1000 * weight

            if channel < minimal_channel:
                return h

            if self.is_recording_style:
                self.recorded_h[location] = torch.std_mean(h, dim=(2, 3), keepdim=True, correction=0)
                return h
            else:
                cond_indices = transformer_options['cond_indices']
                uncond_indices = transformer_options['uncond_indices']
                cond_or_uncond = transformer_options['cond_or_uncond']
                r_std, r_mean = self.recorded_h[location]

                h_c = h[cond_indices]
                h_uc = h[uncond_indices]

                # Cond is always normalized to the reference statistics; uncond
                # blends from fully normalized (style_fidelity=0) to untouched
                # (style_fidelity=1).
                o_c = adain(h_c, r_std, r_mean)
                o_uc_strong = h_uc
                o_uc_weak = adain(h_uc, r_std, r_mean)
                o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity

                # Reassemble the batch in its original cond/uncond order.
                recon = []
                for cx in cond_or_uncond:
                    if cx == 0:
                        recon.append(o_c)
                    else:
                        recon.append(o_uc)

                o = torch.cat(recon, dim=0)
                return o
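        # attn1 replacement: record the reference K/V during the recording
        # pass; during real sampling, let cond queries also attend to the
        # reference tokens.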
        def attn1_proc(q, k, v, transformer_options):
            if not self.use_attn:
                return sdp(q, k, v, transformer_options)

            sigma = transformer_options["sigmas"][0].item()
            if not (sigma_min <= sigma <= sigma_max):
                return sdp(q, k, v, transformer_options)

            location = (transformer_options['block'][0], transformer_options['block'][1],
                        transformer_options['block_index'])

            # Only patch attention layers with enough channels; a higher weight
            # lowers the threshold and therefore affects more layers.
            channel = int(q.shape[2])
            minimal_channel = 1500 - 1280 * weight

            if channel < minimal_channel:
                return sdp(q, k, v, transformer_options)

            if self.is_recording_style:
                self.recorded_attn1[location] = (k, v)
                return sdp(q, k, v, transformer_options)
            else:
                cond_indices = transformer_options['cond_indices']
                uncond_indices = transformer_options['uncond_indices']
                cond_or_uncond = transformer_options['cond_or_uncond']

                q_c = q[cond_indices]
                q_uc = q[uncond_indices]

                k_c = k[cond_indices]
                k_uc = k[uncond_indices]

                v_c = v[cond_indices]
                v_uc = v[uncond_indices]

                k_r, v_r = self.recorded_attn1[location]

                # Cond attends to its own tokens plus the reference tokens
                # (K/V concatenated along the sequence axis); uncond blends
                # from reference-augmented (style_fidelity=0) to plain
                # self-attention (style_fidelity=1).
                o_c = sdp(q_c, zero_cat(k_c, k_r, dim=1), zero_cat(v_c, v_r, dim=1), transformer_options)
                o_uc_strong = sdp(q_uc, k_uc, v_uc, transformer_options)
                o_uc_weak = sdp(q_uc, zero_cat(k_uc, k_r, dim=1), zero_cat(v_uc, v_r, dim=1), transformer_options)
                o_uc = o_uc_weak + (o_uc_strong - o_uc_weak) * style_fidelity

                # Reassemble the batch in its original cond/uncond order.
                recon = []
                for cx in cond_or_uncond:
                    if cx == 0:
                        recon.append(o_c)
                    else:
                        recon.append(o_uc)

                o = torch.cat(recon, dim=0)
                return o
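        # Attach the hooks to the cloned UNet and hand it back to the process.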
        unet.add_block_modifier(block_proc)
        unet.add_conditioning_modifier(conditioning_modifier)
        unet.set_model_replace_all(attn1_proc, 'attn1')

        process.sd_model.forge_objects.unet = unet

        return cond, mask
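# Register the three preprocessor variants with Forge.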
add_supported_preprocessor(PreprocessorReference(
    name='reference_only',
    use_attn=True,
    use_adain=False,
    priority=100
))

add_supported_preprocessor(PreprocessorReference(
    name='reference_adain',
    use_attn=False,
    use_adain=True
))

add_supported_preprocessor(PreprocessorReference(
    name='reference_adain+attn',
    use_attn=True,
    use_adain=True
))