import math
from typing import Dict, List, TypedDict

import numpy as np
import torch

from ..Misc import Logger as log
from ..Misc.BBox import BoundingBox
from .BaseProc import BundleType, CrossAttnProcessorBase

class InjecterProcessor(CrossAttnProcessorBase):
    """Cross-attention processor that edits attention probabilities so that
    the targeted prompt tokens are weakened outside each frame's bounding
    box and strengthened inside it (see ``dd_core``)."""

    def __init__(
        self,
        bundle: BundleType,
        bbox_per_frame: List[BoundingBox],
        name: str,
        strengthen_scale: float = 0.0,
        weaken_scale: float = 1.0,
        is_text2vidzero: bool = False,
    ):
        super().__init__(bundle, is_text2vidzero=is_text2vidzero)
        self.strengthen_scale = strengthen_scale
        self.weaken_scale = weaken_scale
        self.bundle = bundle
        self.num_frames = len(bbox_per_frame)
        self.bbox_per_frame = bbox_per_frame
        self.use_weaken = True
        self.name = name
    def dd_core(self, attention_probs: torch.Tensor):
        """Apply the localized weaken/strengthen edit to the given
        cross-attention probabilities and return the edited copy."""
        frame_size = attention_probs.shape[0] // self.num_frames
        num_affected_frames = self.num_frames
        attention_probs_copied = attention_probs.detach().clone()
        token_inds = self.bundle.get("token_inds")
        trailing_length = self.bundle.get("trailing_length")
        # Indices of the trailing tokens appended right after the prompt.
        trailing_inds = list(
            range(self.len_prompt + 1, self.len_prompt + trailing_length + 1)
        )
        # NOTE: Spatial cross-attention editing (4-D probabilities).
        if len(attention_probs.size()) == 4:
            # Edit the targeted prompt tokens together with the trailing tokens.
            all_tokens_inds = list(set(token_inds).union(set(trailing_inds)))
            # Weight map that is nonzero inside each frame's bounding box.
            strengthen_map = self.localized_weight_map(
                attention_probs_copied,
                token_inds=all_tokens_inds,
                bbox_per_frame=self.bbox_per_frame,
            )
            # Scale attention by weaken_scale wherever the map is zero.
            weaken_map = torch.ones_like(strengthen_map)
            zero_indices = torch.where(strengthen_map == 0)
            weaken_map[zero_indices] = self.weaken_scale
            # Weaken outside the box ...
            attention_probs_copied[..., all_tokens_inds] *= weaken_map[
                ..., all_tokens_inds
            ]
            # ... and strengthen inside it.
            attention_probs_copied[..., all_tokens_inds] += (
                self.strengthen_scale * strengthen_map[..., all_tokens_inds]
            )
        # NOTE: Temporal cross-attention editing (5-D probabilities).
        elif len(attention_probs.size()) == 5:
            strengthen_map = self.localized_temporal_weight_map(
                attention_probs_copied,
                bbox_per_frame=self.bbox_per_frame,
            )
            weaken_map = torch.ones_like(strengthen_map)
            zero_indices = torch.where(strengthen_map == 0)
            weaken_map[zero_indices] = self.weaken_scale
            # Weaken outside the box ...
            attention_probs_copied *= weaken_map
            # ... and strengthen inside it, over all token positions.
            attention_probs_copied += self.strengthen_scale * strengthen_map
        return attention_probs_copied
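

# ---------------------------------------------------------------------------
# Minimal self-contained sketch of the weaken/strengthen gating used in
# dd_core, run on a synthetic 4-D "spatial" attention tensor. The box
# rasterization below is a simplified stand-in for localized_weight_map
# (implemented in the base class); the shapes, token indices, and scales are
# illustrative assumptions, not the project's defaults.
# ---------------------------------------------------------------------------
def _demo_weaken_strengthen() -> torch.Tensor:
    num_frames, heads, res, num_tokens = 2, 1, 4, 8
    strengthen_scale, weaken_scale = 0.2, 0.5
    token_inds = [1, 2]  # hypothetical prompt-token indices to edit

    # Uniform toy attention over (frames, heads, spatial, tokens).
    probs = torch.full((num_frames, heads, res * res, num_tokens), 1.0 / num_tokens)

    # Stand-in strengthen map: 1 inside a fixed 2x2 box, 0 elsewhere; the
    # real localized_weight_map moves the box per frame via bbox_per_frame.
    box = torch.zeros(res, res)
    box[1:3, 1:3] = 1.0
    strengthen_map = torch.zeros_like(probs)
    strengthen_map[..., token_inds] = box.flatten()[None, None, :, None]

    # Same gating as dd_core: weaken outside the box, strengthen inside it.
    weaken_map = torch.ones_like(strengthen_map)
    weaken_map[strengthen_map == 0] = weaken_scale
    edited = probs.detach().clone()
    edited[..., token_inds] *= weaken_map[..., token_inds]
    edited[..., token_inds] += strengthen_scale * strengthen_map[..., token_inds]
    return edited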