# teacache.py
import torch
import numpy as np
from typing import Optional, Dict, Any
from functools import wraps


class TeaCacheConfig:
    """Configuration and runtime state for TeaCache acceleration."""

    def __init__(
        self,
        rel_l1_thresh: float = 0.15,
        enable: bool = True
    ):
        self.rel_l1_thresh = rel_l1_thresh
        self.enable = enable
        self._reset_state()

    def _reset_state(self):
        """Reset internal caching state."""
        self.cnt = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.previous_residual = None


def create_teacache_forward(original_forward):
    """Factory function to create a TeaCache-enabled forward pass."""

    @wraps(original_forward)
    def teacache_forward(
        self,
        hidden_states: torch.Tensor,
        timestep: torch.Tensor,
        encoder_hidden_states: Optional[torch.Tensor] = None,
        encoder_attention_mask: Optional[torch.Tensor] = None,
        pooled_projections: Optional[torch.Tensor] = None,
        guidance: Optional[torch.Tensor] = None,
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ):
        # Fall back to the original forward pass when TeaCache is disabled.
        # Note: `original_forward` is the bound method saved by
        # enable_teacache, so `self` is implicit and must not be passed again.
        if not hasattr(self, 'teacache_config') or not self.teacache_config.enable:
            return original_forward(
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=return_dict,
            )

        config = self.teacache_config

        # Prepare modulation vectors, similar to the HunyuanVideo implementation.
        vec = None
        if pooled_projections is not None:
            vec = self.vector_in(pooled_projections)
        if guidance is not None:
            if vec is None:
                vec = self.guidance_in(guidance)
            else:
                vec = vec + self.guidance_in(guidance)

        # TeaCache decision signal: the modulated input to the first block.
        inp = hidden_states.clone()
        if hasattr(self, 'double_blocks') and hasattr(self.double_blocks[0], 'img_norm1'):
            # HunyuanVideo-specific modulation.
            img_mod1_shift, img_mod1_scale, _, _, _, _ = (
                self.double_blocks[0].img_mod(vec).chunk(6, dim=-1)
            )
            normed_inp = self.double_blocks[0].img_norm1(inp)
            modulated_inp = normed_inp * (1 + img_mod1_scale) + img_mod1_shift
        else:
            # Fallback modulation for models without HunyuanVideo-style blocks.
            normed_inp = self.transformer_blocks[0].norm1(inp)
            modulated_inp = normed_inp

        # Decide whether to recompute the full forward pass or reuse the cache.
        should_calc = True
        if config.cnt == 0 or config.cnt == self.num_inference_steps - 1:
            # Always compute on the first and last steps.
            should_calc = True
            config.accumulated_rel_l1_distance = 0
        elif config.previous_modulated_input is not None:
            # Polynomial rescaling of the relative L1 distance (coefficients
            # fitted offline for HunyuanVideo).
            coefficients = [7.33226126e+02, -4.01131952e+02, 6.75869174e+01,
                            -3.14987800e+00, 9.61237896e-02]
            rescale_func = np.poly1d(coefficients)
            rel_l1 = ((modulated_inp - config.previous_modulated_input).abs().mean() /
                      config.previous_modulated_input.abs().mean()).cpu().item()
            config.accumulated_rel_l1_distance += rescale_func(rel_l1)
            # Recompute only once the accumulated distance crosses the threshold.
            should_calc = config.accumulated_rel_l1_distance >= config.rel_l1_thresh
            if should_calc:
                config.accumulated_rel_l1_distance = 0

        config.previous_modulated_input = modulated_inp
        config.cnt += 1
        if config.cnt >= self.num_inference_steps:
            config.cnt = 0

        # Either reuse the cached residual or run the full forward pass.
        if not should_calc and config.previous_residual is not None:
            hidden_states = hidden_states + config.previous_residual
        else:
            ori_hidden_states = hidden_states.clone()
            out = original_forward(
                hidden_states=hidden_states,
                timestep=timestep,
                encoder_hidden_states=encoder_hidden_states,
                encoder_attention_mask=encoder_attention_mask,
                pooled_projections=pooled_projections,
                guidance=guidance,
                attention_kwargs=attention_kwargs,
                return_dict=True,
            )
            hidden_states = out["sample"]
            # Store the residual so skipped steps can approximate this output.
            config.previous_residual = hidden_states - ori_hidden_states

        if not return_dict:
            return (hidden_states,)
        return {"sample": hidden_states}

    return teacache_forward


def enable_teacache(model: Any, num_inference_steps: int, rel_l1_thresh: float = 0.15):
    """Enable TeaCache acceleration for a model."""
    if not hasattr(model, '_original_forward'):
        model._original_forward = model.forward
    model.teacache_config = TeaCacheConfig(rel_l1_thresh=rel_l1_thresh)
    model.num_inference_steps = num_inference_steps
    model.forward = create_teacache_forward(model._original_forward).__get__(model)


def disable_teacache(model: Any):
    """Disable TeaCache acceleration and restore the original forward pass."""
    if hasattr(model, '_original_forward'):
        model.forward = model._original_forward
        del model._original_forward
    if hasattr(model, 'teacache_config'):
        del model.teacache_config
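

# ---------------------------------------------------------------------------
# Smoke test: a minimal sketch of how the patching is meant to be used.
# `ToyBlock` and `ToyTransformer` are hypothetical stand-ins (not part of any
# real library) that exercise the fallback modulation path; a real target
# would be a diffusers-style video transformer exposing `double_blocks` or
# `transformer_blocks`, plus `vector_in` and `guidance_in`.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import torch.nn as nn

    class ToyBlock(nn.Module):
        """Hypothetical block exposing the `norm1` attribute TeaCache probes."""

        def __init__(self, dim: int):
            super().__init__()
            self.norm1 = nn.LayerNorm(dim)

    class ToyTransformer(nn.Module):
        """Hypothetical model matching the forward signature TeaCache wraps."""

        def __init__(self, dim: int = 32):
            super().__init__()
            self.transformer_blocks = nn.ModuleList([ToyBlock(dim)])
            self.proj = nn.Linear(dim, dim)

        def forward(self, hidden_states, timestep, encoder_hidden_states=None,
                    encoder_attention_mask=None, pooled_projections=None,
                    guidance=None, attention_kwargs=None, return_dict=True):
            out = hidden_states + self.proj(hidden_states)
            return {"sample": out} if return_dict else (out,)

    model = ToyTransformer()
    steps = 4
    enable_teacache(model, num_inference_steps=steps, rel_l1_thresh=0.15)

    x = torch.randn(1, 8, 32)
    with torch.no_grad():
        for step in range(steps):
            # Cached steps reuse the stored residual instead of recomputing.
            x = model(x, timestep=torch.tensor([step]))["sample"]

    disable_teacache(model)
    assert not hasattr(model, "teacache_config")
    print("TeaCache smoke test passed:", tuple(x.shape))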