from __future__ import annotations
from typing import TYPE_CHECKING, Union
from comfy_api.latest import io, ComfyExtension
import comfy.patcher_extension
import logging
import torch
import comfy.model_patcher

if TYPE_CHECKING:
    from uuid import UUID
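
# Overview of the caching approach implemented below:
# - EasyCache wraps the diffusion model forward (per-cond, keyed by cond UUIDs);
#   LazyCache wraps predict_noise (one cached diff for the whole call).
# - Each step, the change of the model input relative to the previous step is measured on a
#   subsampled slice of x. Multiplied by a running input->output "relative transformation rate",
#   this approximates how much the output would change.
# - Approximate change rates are accumulated; while the cumulative rate stays below
#   reuse_threshold, the model call is skipped and the cached (output - input) diff is re-applied
#   to the current input instead. Whenever a real call is made, the cached diff, subsampled
#   tensors, and output norm are refreshed and the accumulator is reset to 0.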


def easycache_forward_wrapper(executor, *args, **kwargs):
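    """DIFFUSION_MODEL wrapper: estimate the output change from the input change and either skip
    the wrapped forward (re-applying the cached per-cond diff) or run it and refresh the cache."""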
    # get values from args
    x: torch.Tensor = args[0]
    transformer_options: dict[str] = args[-1]
    if not isinstance(transformer_options, dict):
        transformer_options = kwargs.get("transformer_options")
        if not transformer_options:
            transformer_options = args[-2]
    easycache: EasyCacheHolder = transformer_options["easycache"]
    sigmas = transformer_options["sigmas"]
    uuids = transformer_options["uuids"]
    if sigmas is not None and easycache.is_past_end_timestep(sigmas):
        return executor(*args, **kwargs)
    # prepare next x_prev
    has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(sigmas)
    if do_easycache:
        easycache.check_metadata(x)
        # if first cond marked this step for skipping, skip it and use appropriate cached values
        if easycache.skip_current_step:
            if easycache.verbose:
                logging.info(f"EasyCache [verbose] - was marked to skip this step by {easycache.first_cond_uuid}. Present uuids: {uuids}")
            return easycache.apply_cache_diff(x, uuids)
        if easycache.initial_step:
            easycache.first_cond_uuid = uuids[0]
            has_first_cond_uuid = easycache.has_first_cond_uuid(uuids)
            easycache.initial_step = False
        if has_first_cond_uuid:
            if easycache.has_x_prev_subsampled():
                input_change = (easycache.subsample(x, uuids, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
                if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                    approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                    easycache.cumulative_change_rate += approx_output_change_rate
                    if easycache.cumulative_change_rate < easycache.reuse_threshold:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        # other conds should also skip this step, and instead use their cached values
                        easycache.skip_current_step = True
                        return easycache.apply_cache_diff(x, uuids)
                    else:
                        if easycache.verbose:
                            logging.info(f"EasyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                        easycache.cumulative_change_rate = 0.0
    output: torch.Tensor = executor(*args, **kwargs)
    if has_first_cond_uuid and easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, uuids, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                if easycache.verbose:
                    logging.info(f"EasyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev, uuids)
    if has_first_cond_uuid:
        easycache.x_prev_subsampled = easycache.subsample(next_x_prev, uuids)
        easycache.output_prev_subsampled = easycache.subsample(output, uuids)
        easycache.output_prev_norm = output.flatten().abs().mean()
        if easycache.verbose:
            logging.info(f"EasyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


def lazycache_predict_noise_wrapper(executor, *args, **kwargs):
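    """PREDICT_NOISE wrapper: same change-rate estimate as EasyCache, but applied to the whole
    predict_noise call with a single cached diff instead of per-cond diffs."""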
    # get values from args
    x: torch.Tensor = args[0]
    timestep: float = args[1]
    model_options: dict[str] = args[2]
    easycache: LazyCacheHolder = model_options["transformer_options"]["easycache"]
    if easycache.is_past_end_timestep(timestep):
        return executor(*args, **kwargs)
    # prepare next x_prev
    next_x_prev = x
    input_change = None
    do_easycache = easycache.should_do_easycache(timestep)
    if do_easycache:
        easycache.check_metadata(x)
        if easycache.has_x_prev_subsampled():
            input_change = (easycache.subsample(x, clone=False) - easycache.x_prev_subsampled).flatten().abs().mean()
            if easycache.has_output_prev_norm() and easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.cumulative_change_rate += approx_output_change_rate
                if easycache.cumulative_change_rate < easycache.reuse_threshold:
                    if easycache.verbose:
                        logging.info(f"LazyCache [verbose] - skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    # other conds should also skip this step, and instead use their cached values
                    easycache.skip_current_step = True
                    return easycache.apply_cache_diff(x)
                else:
                    if easycache.verbose:
                        logging.info(f"LazyCache [verbose] - NOT skipping step; cumulative_change_rate: {easycache.cumulative_change_rate}, reuse_threshold: {easycache.reuse_threshold}")
                    easycache.cumulative_change_rate = 0.0
    output: torch.Tensor = executor(*args, **kwargs)
    if easycache.has_output_prev_norm():
        output_change = (easycache.subsample(output, clone=False) - easycache.output_prev_subsampled).flatten().abs().mean()
        if easycache.verbose:
            output_change_rate = output_change / easycache.output_prev_norm
            easycache.output_change_rates.append(output_change_rate.item())
            if easycache.has_relative_transformation_rate():
                approx_output_change_rate = (easycache.relative_transformation_rate * input_change) / easycache.output_prev_norm
                easycache.approx_output_change_rates.append(approx_output_change_rate.item())
                if easycache.verbose:
                    logging.info(f"LazyCache [verbose] - approx_output_change_rate: {approx_output_change_rate}")
        if input_change is not None:
            easycache.relative_transformation_rate = output_change / input_change
        if easycache.verbose:
            logging.info(f"LazyCache [verbose] - output_change_rate: {output_change_rate}")
    # TODO: allow cache_diff to be offloaded
    easycache.update_cache_diff(output, next_x_prev)
    easycache.x_prev_subsampled = easycache.subsample(next_x_prev)
    easycache.output_prev_subsampled = easycache.subsample(output)
    easycache.output_prev_norm = output.flatten().abs().mean()
    if easycache.verbose:
        logging.info(f"LazyCache [verbose] - x_prev_subsampled: {easycache.x_prev_subsampled.shape}")
    return output


def easycache_calc_cond_batch_wrapper(executor, *args, **kwargs):
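    """CALC_COND_BATCH wrapper: clears the per-step skip flag before the conds for this step are evaluated."""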
    model_options = args[-1]
    easycache: EasyCacheHolder = model_options["transformer_options"]["easycache"]
    easycache.skip_current_step = False
    # TODO: check if first_cond_uuid is active at this timestep; otherwise, EasyCache needs to be partially reset
    return executor(*args, **kwargs)


def easycache_sample_wrapper(executor, *args, **kwargs):
    """
    This OUTER_SAMPLE wrapper makes sure easycache is prepped for the current run, and that all memory usage is cleared at the end.
    """
    try:
        guider = executor.class_obj
        orig_model_options = guider.model_options
        guider.model_options = comfy.model_patcher.create_model_options_clone(orig_model_options)
        # clone and prepare timesteps
        guider.model_options["transformer_options"]["easycache"] = guider.model_options["transformer_options"]["easycache"].clone().prepare_timesteps(guider.model_patcher.model.model_sampling)
        easycache: Union[EasyCacheHolder, LazyCacheHolder] = guider.model_options['transformer_options']['easycache']
        logging.info(f"{easycache.name} enabled - threshold: {easycache.reuse_threshold}, start_percent: {easycache.start_percent}, end_percent: {easycache.end_percent}")
        return executor(*args, **kwargs)
    finally:
        easycache = guider.model_options['transformer_options']['easycache']
        output_change_rates = easycache.output_change_rates
        approx_output_change_rates = easycache.approx_output_change_rates
        if easycache.verbose:
            logging.info(f"{easycache.name} [verbose] - output_change_rates {len(output_change_rates)}: {output_change_rates}")
            logging.info(f"{easycache.name} [verbose] - approx_output_change_rates {len(approx_output_change_rates)}: {approx_output_change_rates}")
        total_steps = len(args[3]) - 1
        logging.info(f"{easycache.name} - skipped {easycache.total_steps_skipped}/{total_steps} steps ({total_steps/(total_steps-easycache.total_steps_skipped):.2f}x speedup).")
        easycache.reset()
        guider.model_options = orig_model_options


class EasyCacheHolder:
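    """Per-run state for EasyCache: thresholds, timestep window, running change-rate estimates,
    and cached per-cond (output - input) diffs keyed by cond UUID."""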

    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
        self.name = "EasyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        # cache values
        self.first_cond_uuid = None
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.uuid_cache_diffs: dict[UUID, torch.Tensor] = {}
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0
        # how to deal with mismatched dims
        self.allow_mismatch = True
        self.cut_from_start = True
        self.state_metadata = None

    def is_past_end_timestep(self, timestep: float) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: float) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def subsample(self, x: torch.Tensor, uuids: list[UUID], clone: bool = True) -> torch.Tensor:
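        """Return the first cond's slice of the batch, strided by subsample_factor along the last two dims (when > 1)."""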
        batch_offset = x.shape[0] // len(uuids)
        uuid_idx = uuids.index(self.first_cond_uuid)
        if self.subsample_factor > 1:
            to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        to_return = x[uuid_idx*batch_offset:(uuid_idx+1)*batch_offset, ...]
        if clone:
            return to_return.clone()
        return to_return

    def apply_cache_diff(self, x: torch.Tensor, uuids: list[UUID]):
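        """Reconstruct the skipped output by adding each cond's cached (output - input) diff onto x;
        when dims mismatch and allow_mismatch is set, x is cut down to the cached shape first (cosmos world2video)."""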
        if self.first_cond_uuid in uuids:
            self.total_steps_skipped += 1
        batch_offset = x.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            # if cached dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
            if x.shape[1:] != self.uuid_cache_diffs[uuid].shape[1:]:
                if not self.allow_mismatch:
                    raise ValueError(f"Cached dims {self.uuid_cache_diffs[uuid].shape} don't match x dims {x.shape} - this is no good")
                slicing = []
                skip_this_dim = True
                for dim_u, dim_x in zip(self.uuid_cache_diffs[uuid].shape, x.shape):
                    if skip_this_dim:
                        skip_this_dim = False
                        continue
                    if dim_u != dim_x:
                        if self.cut_from_start:
                            slicing.append(slice(dim_x-dim_u, None))
                        else:
                            slicing.append(slice(None, dim_u))
                    else:
                        slicing.append(slice(None))
                slicing = [slice(i*batch_offset, (i+1)*batch_offset)] + slicing
                x = x[slicing]
            x += self.uuid_cache_diffs[uuid].to(x.device)
        return x

    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor, uuids: list[UUID]):
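        """Store per-cond (output - x) diffs keyed by UUID, cutting excess from x along mismatched non-batch dims first."""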
        # if output dims don't match x dims, cut off excess and hope for the best (cosmos world2video)
        if output.shape[1:] != x.shape[1:]:
            if not self.allow_mismatch:
                raise ValueError(f"Output dims {output.shape} don't match x dims {x.shape} - this is no good")
            slicing = []
            skip_dim = True
            for dim_o, dim_x in zip(output.shape, x.shape):
                if not skip_dim and dim_o != dim_x:
                    if self.cut_from_start:
                        slicing.append(slice(dim_x-dim_o, None))
                    else:
                        slicing.append(slice(None, dim_o))
                else:
                    slicing.append(slice(None))
                skip_dim = False
            x = x[slicing]
        diff = output - x
        batch_offset = diff.shape[0] // len(uuids)
        for i, uuid in enumerate(uuids):
            self.uuid_cache_diffs[uuid] = diff[i*batch_offset:(i+1)*batch_offset, ...]

    def has_first_cond_uuid(self, uuids: list[UUID]) -> bool:
        return self.first_cond_uuid in uuids

    def check_metadata(self, x: torch.Tensor) -> bool:
        metadata = (x.device, x.dtype, x.shape[1:])
        if self.state_metadata is None:
            self.state_metadata = metadata
            return True
        if metadata == self.state_metadata:
            return True
        logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
        self.reset()
        return False

    def reset(self):
        self.relative_transformation_rate = 0.0
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.skip_current_step = False
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.first_cond_uuid = None
        del self.x_prev_subsampled
        self.x_prev_subsampled = None
        del self.output_prev_subsampled
        self.output_prev_subsampled = None
        del self.output_prev_norm
        self.output_prev_norm = None
        del self.uuid_cache_diffs
        self.uuid_cache_diffs = {}
        self.total_steps_skipped = 0
        self.state_metadata = None
        return self

    def clone(self):
        return EasyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)


class EasyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="EasyCache",
            display_name="EasyCache",
            description="Native EasyCache implementation.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add EasyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of EasyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of EasyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with EasyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = EasyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "easycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.CALC_COND_BATCH, "easycache", easycache_calc_cond_batch_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, "easycache", easycache_forward_wrapper)
        return io.NodeOutput(model)


class LazyCacheHolder:
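    """Per-run state for LazyCache: same change-rate bookkeeping as EasyCacheHolder, but with a
    single cached diff for the whole predict_noise call instead of per-cond diffs."""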

    def __init__(self, reuse_threshold: float, start_percent: float, end_percent: float, subsample_factor: int, offload_cache_diff: bool, verbose: bool=False):
        self.name = "LazyCache"
        self.reuse_threshold = reuse_threshold
        self.start_percent = start_percent
        self.end_percent = end_percent
        self.subsample_factor = subsample_factor
        self.offload_cache_diff = offload_cache_diff
        self.verbose = verbose
        # timestep values
        self.start_t = 0.0
        self.end_t = 0.0
        # control values
        self.relative_transformation_rate: float = None
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        # cache values
        self.x_prev_subsampled: torch.Tensor = None
        self.output_prev_subsampled: torch.Tensor = None
        self.output_prev_norm: torch.Tensor = None
        self.cache_diff: torch.Tensor = None
        self.output_change_rates = []
        self.approx_output_change_rates = []
        self.total_steps_skipped = 0
        self.state_metadata = None

    def has_cache_diff(self) -> bool:
        return self.cache_diff is not None

    def is_past_end_timestep(self, timestep: float) -> bool:
        return not (timestep[0] > self.end_t).item()

    def should_do_easycache(self, timestep: float) -> bool:
        return (timestep[0] <= self.start_t).item()

    def has_x_prev_subsampled(self) -> bool:
        return self.x_prev_subsampled is not None

    def has_output_prev_subsampled(self) -> bool:
        return self.output_prev_subsampled is not None

    def has_output_prev_norm(self) -> bool:
        return self.output_prev_norm is not None

    def has_relative_transformation_rate(self) -> bool:
        return self.relative_transformation_rate is not None

    def prepare_timesteps(self, model_sampling):
        self.start_t = model_sampling.percent_to_sigma(self.start_percent)
        self.end_t = model_sampling.percent_to_sigma(self.end_percent)
        return self

    def subsample(self, x: torch.Tensor, clone: bool = True) -> torch.Tensor:
        if self.subsample_factor > 1:
            to_return = x[..., ::self.subsample_factor, ::self.subsample_factor]
            if clone:
                return to_return.clone()
            return to_return
        if clone:
            return x.clone()
        return x

    def apply_cache_diff(self, x: torch.Tensor):
        self.total_steps_skipped += 1
        return x + self.cache_diff.to(x.device)

    def update_cache_diff(self, output: torch.Tensor, x: torch.Tensor):
        self.cache_diff = output - x

    def check_metadata(self, x: torch.Tensor) -> bool:
        metadata = (x.device, x.dtype, x.shape)
        if self.state_metadata is None:
            self.state_metadata = metadata
            return True
        if metadata == self.state_metadata:
            return True
        logging.warning(f"{self.name} - Tensor shape, dtype or device changed, resetting state")
        self.reset()
        return False

    def reset(self):
        self.relative_transformation_rate = 0.0
        self.cumulative_change_rate = 0.0
        self.initial_step = True
        self.output_change_rates = []
        self.approx_output_change_rates = []
        del self.cache_diff
        self.cache_diff = None
        del self.x_prev_subsampled
        self.x_prev_subsampled = None
        del self.output_prev_subsampled
        self.output_prev_subsampled = None
        del self.output_prev_norm
        self.output_prev_norm = None
        self.total_steps_skipped = 0
        self.state_metadata = None
        return self

    def clone(self):
        return LazyCacheHolder(self.reuse_threshold, self.start_percent, self.end_percent, self.subsample_factor, self.offload_cache_diff, self.verbose)


class LazyCacheNode(io.ComfyNode):
    @classmethod
    def define_schema(cls) -> io.Schema:
        return io.Schema(
            node_id="LazyCache",
            display_name="LazyCache",
            description="A homebrew version of EasyCache - an even 'easier' version of EasyCache to implement. Overall it works worse than EasyCache, but it is better in some rare cases AND universally compatible with everything in ComfyUI.",
            category="advanced/debug/model",
            is_experimental=True,
            inputs=[
                io.Model.Input("model", tooltip="The model to add LazyCache to."),
                io.Float.Input("reuse_threshold", min=0.0, default=0.2, max=3.0, step=0.01, tooltip="The threshold for reusing cached steps."),
                io.Float.Input("start_percent", min=0.0, default=0.15, max=1.0, step=0.01, tooltip="The relative sampling step to begin use of LazyCache."),
                io.Float.Input("end_percent", min=0.0, default=0.95, max=1.0, step=0.01, tooltip="The relative sampling step to end use of LazyCache."),
                io.Boolean.Input("verbose", default=False, tooltip="Whether to log verbose information."),
            ],
            outputs=[
                io.Model.Output(tooltip="The model with LazyCache."),
            ],
        )

    @classmethod
    def execute(cls, model: io.Model.Type, reuse_threshold: float, start_percent: float, end_percent: float, verbose: bool) -> io.NodeOutput:
        model = model.clone()
        model.model_options["transformer_options"]["easycache"] = LazyCacheHolder(reuse_threshold, start_percent, end_percent, subsample_factor=8, offload_cache_diff=False, verbose=verbose)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.OUTER_SAMPLE, "lazycache", easycache_sample_wrapper)
        model.add_wrapper_with_key(comfy.patcher_extension.WrappersMP.PREDICT_NOISE, "lazycache", lazycache_predict_noise_wrapper)
        return io.NodeOutput(model)


class EasyCacheExtension(ComfyExtension):
    async def get_node_list(self) -> list[type[io.ComfyNode]]:
        return [
            EasyCacheNode,
            LazyCacheNode,
        ]


def comfy_entrypoint():
    return EasyCacheExtension()