|
import torch |
|
from ldm_patched.modules.conds import CONDRegular, CONDCrossAttn |
|
from ldm_patched.modules.samplers import sampling_function |
|
from ldm_patched.modules import model_management |
|
from ldm_patched.modules.ops import cleanup_cache |
|
|
|
|
|
def cond_from_a1111_to_patched_ldm(cond):
    """Convert an A1111-style conditioning into a one-element list of ldm_patched cond dicts.

    A ``torch.Tensor`` is treated as cross-attention context only. Any other
    value is indexed with the ``'crossattn'`` and ``'vector'`` keys (presumably
    an SDXL-style conditioning dict — confirm against callers); the ``'vector'``
    entry is forwarded both as ``pooled_output`` and as the ``y`` model condition.

    Args:
        cond: a tensor, or a mapping providing ``'crossattn'`` and ``'vector'``.

    Returns:
        A list containing a single dict with ``cross_attn``, ``model_conds`` and,
        for the mapping form only, ``pooled_output``.
    """
    if not isinstance(cond, torch.Tensor):
        context = cond['crossattn']
        pooled = cond['vector']
        entry = {
            'cross_attn': context,
            'pooled_output': pooled,
            'model_conds': {
                'c_crossattn': CONDCrossAttn(context),
                'y': CONDRegular(pooled),
            },
        }
        return [entry]

    # Plain tensor: cross-attention only, no pooled vector available.
    entry = {
        'cross_attn': cond,
        'model_conds': {
            'c_crossattn': CONDCrossAttn(cond),
        },
    }
    return [entry]
|
|
|
|
|
def cond_from_a1111_to_patched_ldm_weighted(cond, weights):
    """Split a batched conditioning into weighted per-subprompt ldm_patched cond dicts.

    ``weights`` is iterated per batch item. Each item is a sequence of
    ``(index, weight)`` pairs (presumably A1111's composable-prompt
    ``cond_composition`` — confirm against callers). The pairs are regrouped by
    position, so group ``j`` collects the ``j``-th pair of every batch item.
    Because ``zip`` stops at the shortest batch item, pairs beyond that length
    are ignored.

    For each group, the rows at the collected indices are selected from ``cond``.
    The selection uses ``cond.advanced_indexing(indices)`` when that method
    exists and ``cond[indices]`` otherwise. The result is converted with
    ``cond_from_a1111_to_patched_ldm`` and tagged with a ``strength`` key.

    Args:
        cond: batched conditioning (tensor or an object exposing ``advanced_indexing``).
        weights: per-batch-item sequences of ``(index, weight)`` pairs.

    Returns:
        A flat list with one cond dict per group.
    """
    converted = []

    for group in zip(*weights):
        indices = [index for index, _ in group]
        # Only the weight of the group's last pair is kept; the other batch items
        # are assumed to carry the same weight. NOTE(review): confirm this holds.
        strength = group[-1][1] if group else 0

        if hasattr(cond, 'advanced_indexing'):
            selected = cond.advanced_indexing(indices)
        else:
            selected = cond[indices]

        entries = cond_from_a1111_to_patched_ldm(selected)
        entries[0]['strength'] = strength
        converted.extend(entries)

    return converted
|
|
|
|
|
def forge_sample(self, denoiser_params, cond_scale, cond_composition):
    """Compute the denoised prediction for one call through ldm_patched's ``sampling_function``.

    ``self`` is presumably an A1111 CFG denoiser wrapper (TODO confirm). The
    function reads its UNet patcher from
    ``self.inner_model.inner_model.forge_objects.unet`` and then:

    1. Converts the A1111 uncond/cond into ldm_patched cond dicts, applying
       ``cond_composition`` weights to the cond.
    2. Attaches a ``c_concat`` condition when the image conditioning's batch
       and spatial sizes match ``x``.
    3. Attaches the controlnet chain to every cond dict, if one is set.
    4. Lets each ``conditioning_modifiers`` entry in the model options rewrite
       the sampling arguments.

    Args:
        denoiser_params: provides ``x``, ``sigma``, ``text_cond``,
            ``text_uncond`` and ``image_cond``.
        cond_scale: guidance scale forwarded to ``sampling_function``.
        cond_composition: per-batch ``(index, weight)`` pairs for the cond.

    Returns:
        Whatever ``sampling_function`` returns.
    """
    unet_patcher = self.inner_model.inner_model.forge_objects.unet
    model = unet_patcher.model
    control = unet_patcher.controlnet_linked_list
    extra_concat_condition = unet_patcher.extra_concat_condition

    x = denoiser_params.x
    timestep = denoiser_params.sigma
    uncond = cond_from_a1111_to_patched_ldm(denoiser_params.text_uncond)
    cond = cond_from_a1111_to_patched_ldm_weighted(denoiser_params.text_cond, cond_composition)
    model_options = unet_patcher.model_options
    seed = self.p.seeds[0]

    # A concat condition stored on the UNet patcher takes precedence over A1111's image_cond.
    image_cond_in = denoiser_params.image_cond if extra_concat_condition is None else extra_concat_condition

    if isinstance(image_cond_in, torch.Tensor):
        # Attach only when dims 0, 2 and 3 agree with x (dim 1 may differ).
        layout_matches = (
            image_cond_in.shape[0] == x.shape[0]
            and image_cond_in.shape[2] == x.shape[2]
            and image_cond_in.shape[3] == x.shape[3]
        )
        if layout_matches:
            # Each entry gets its own CONDRegular wrapper.
            for entry in uncond + cond:
                entry['model_conds']['c_concat'] = CONDRegular(image_cond_in)

    if control is not None:
        for entry in cond + uncond:
            entry['control'] = control

    # Modifiers may replace any sampling argument, including model_options itself;
    # the list iterated here is the one present before the loop starts.
    for modifier in model_options.get('conditioning_modifiers', []):
        (model, x, timestep, uncond, cond, cond_scale, model_options, seed) = modifier(
            model, x, timestep, uncond, cond, cond_scale, model_options, seed)

    return sampling_function(model, x, timestep, uncond, cond, cond_scale, model_options, seed)
|
|
|
|
|
def sampling_prepare(unet, x):
    """Load the UNet and auxiliary model patchers onto the GPU, then prime controlnets.

    Args:
        unet: the UNet patcher. It provides ``model_options``,
            ``memory_required``, ``extra_preserved_memory_during_sampling``,
            ``extra_model_patchers_during_sampling``, ``controlnet_linked_list``,
            ``model_dtype()``, ``model`` and ``list_controlnets()``.
        x: latent batch; its shape is unpacked as ``(B, C, H, W)``.

    Returns:
        None.
    """
    B, C, H, W = x.shape

    # A model option may override the default peak-memory estimator.
    memory_estimation_function = unet.model_options.get('memory_peak_estimation_modifier', unet.memory_required)

    # Estimated for a doubled batch — presumably cond and uncond run together (TODO confirm).
    unet_inference_memory = memory_estimation_function([B * 2, C, H, W])
    additional_inference_memory = unet.extra_preserved_memory_during_sampling
    # Copy the list: the original code extended the patcher's own list in place
    # with `+=`, so controlnet models piled up on `unet` across sampling runs.
    additional_model_patchers = list(unet.extra_model_patchers_during_sampling)

    controlnets = unet.controlnet_linked_list
    if controlnets is not None:
        additional_inference_memory += controlnets.inference_memory_requirements(unet.model_dtype())
        additional_model_patchers += controlnets.get_models()

    model_management.load_models_gpu(
        models=[unet] + additional_model_patchers,
        memory_required=unet_inference_memory + additional_inference_memory)

    real_model = unet.model

    def percent_to_timestep_function(p):
        # Maps a sampling-progress fraction to a sigma via the model's sampling schedule.
        return real_model.model_sampling.percent_to_sigma(p)

    for cnet in unet.list_controlnets():
        cnet.pre_run(real_model, percent_to_timestep_function)

    return
|
|
|
|
|
def sampling_cleanup(unet):
    """Release per-run controlnet state after sampling.

    Calls ``cleanup()`` on every controlnet returned by ``unet.list_controlnets()``,
    then calls ``cleanup_cache()`` from ``ldm_patched.modules.ops``.

    Returns:
        None.
    """
    controlnets = unet.list_controlnets()
    for controlnet in controlnets:
        controlnet.cleanup()
    cleanup_cache()
|
|