from typing import Dict, List, TypedDict

import numpy as np
import math
import torch
from abc import ABC, abstractmethod
from diffusers.models.attention_processor import Attention as CrossAttention
from einops import rearrange

from ..Misc import Logger as log
from ..Misc.BBox import BoundingBox

KERNEL_DIVISION = 3.
INJECTION_SCALE = 1.0


def reshape_fortran(x, shape):
    """Reshape a tensor in Fortran (column-major) index order. See
    https://stackoverflow.com/a/63964246
    """
    if len(x.shape) > 0:
        x = x.permute(*reversed(range(len(x.shape))))
    return x.reshape(*reversed(shape)).permute(*reversed(range(len(shape))))
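
# Illustrative check of reshape_fortran (not part of the pipeline): the result is
# expected to match numpy's Fortran-order reshape.
#   >>> x = torch.arange(6).reshape(2, 3)
#   >>> reshape_fortran(x, (3, 2))
#   tensor([[0, 4],
#           [3, 2],
#           [1, 5]])
#   >>> np.reshape(x.numpy(), (3, 2), order="F")  # same values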


def gaussian_2d(x=0, y=0, mx=0, my=0, sx=1, sy=1):
    """2D Gaussian weight function, normalized in-place so its peak value is 1.
    """
    gaussian_map = (
        1
        / (2 * math.pi * sx * sy)
        * torch.exp(-((x - mx) ** 2 / (2 * sx**2) + (y - my) ** 2 / (2 * sy**2)))
    )
    gaussian_map.div_(gaussian_map.max())
    return gaussian_map
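
# Minimal usage sketch (sizes are illustrative): build a peak-normalized Gaussian
# patch over an h x w grid, the same way the localized weight maps below do it.
#   h, w = 16, 24
#   xs, ys = torch.meshgrid(
#       torch.linspace(0, h, h), torch.linspace(0, w, w), indexing="ij"
#   )
#   patch = gaussian_2d(xs, ys, mx=h // 2, my=w // 2,
#                       sx=h / KERNEL_DIVISION, sy=w / KERNEL_DIVISION)
#   # patch.shape == (h, w) and patch.max() == 1.0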


class BundleType(TypedDict):
    selected_inds: List[int]  # 1-indexed prompt-token indices of a subject
    trailing_inds: List[int]  # 1-indexed indices of the subject's trailing tokens
    bbox: List[float]  # four floats defining the bounding box [left, right, top, bottom]
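
# Hypothetical example of a bundle (all values are illustrative only):
#   bundle: BundleType = {
#       "selected_inds": [2, 3],       # tokens describing the subject
#       "trailing_inds": [4, 5],       # tokens trailing the subject
#       "bbox": [0.1, 0.6, 0.2, 0.8],  # [left, right, top, bottom] as frame ratios
#   }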


class CrossAttnProcessorBase:

    MAX_LEN_CLIP_TOKENS = 77
    DEVICE = "cuda"

    def __init__(self, bundle, is_text2vidzero=False):
        self.prompt = bundle["prompt_base"]
        # Number of words in the base prompt (everything before the first ";").
        base_prompt = self.prompt.split(";")[0]
        self.len_prompt = len(base_prompt.split(" "))
        # Number of words in the full prompt.
        self.prompt_len = len(self.prompt.split(" "))
        self.use_dd = False
        self.use_dd_temporal = False
        self.unet_chunk_size = 2
        self._cross_attention_map = None
        self._loss = None
        self._parameters = None
        self.is_text2vidzero = is_text2vidzero
        bbox = None

    def cross_attention_map(self):
        return self._cross_attention_map

    def loss(self):
        return self._loss

    def parameters(self):
        if self._parameters is None:
            log.warn("No parameters have been initialized. Be cautious!")
        return self._parameters

    def __call__(
        self,
        attn: CrossAttention,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
    ):
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            # (older diffusers versions exposed this as attn.cross_attention_norm)
            encoder_hidden_states = attn.norm_cross(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        def rearrange_3(tensor, f):
            # (B * F, D, C) -> (B, F, D, C)
            F, D, C = tensor.size()
            return torch.reshape(tensor, (F // f, f, D, C))

        def rearrange_4(tensor):
            # (B, F, D, C) -> (B * F, D, C)
            B, F, D, C = tensor.size()
            return torch.reshape(tensor, (B * F, D, C))

        # Cross-frame attention (Text2Video-Zero): every frame attends to the
        # keys/values of the first frame.
        if not is_cross_attention and self.is_text2vidzero:
            video_length = key.size()[0] // 2
            first_frame_index = [0] * video_length
            # rearrange keys to have batch and frames in the 1st and 2nd dims respectively
            key = rearrange_3(key, video_length)
            key = key[:, first_frame_index]
            # rearrange values to have batch and frames in the 1st and 2nd dims respectively
            value = rearrange_3(value, video_length)
            value = value[:, first_frame_index]
            # rearrange back to original shape
            key = rearrange_4(key)
            value = rearrange_4(value)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        # Cross-attention map
        attention_probs = attn.get_attention_scores(query, key)
        # Example shape dumps observed while debugging:
        # torch.Size([960, 77, 64]) torch.Size([960, 256, 64]) torch.Size([960, 77, 64]) torch.Size([960, 256, 77])
        # torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 64]) torch.Size([10240, 24, 24])

        # Split point between the two halves of the classifier-free-guidance batch.
        n = attention_probs.shape[0] // 2
        if attention_probs.shape[-1] == CrossAttnProcessorBase.MAX_LEN_CLIP_TOKENS:
            dim = int(np.sqrt(attention_probs.shape[1]))
            if self.use_dd:
                # Let the DD variant edit the conditional half of the spatial
                # cross-attention map, then write it back.
                attention_probs_4d = attention_probs.view(
                    attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
                )[n:]
                attention_probs_4d = self.dd_core(attention_probs_4d)
                attention_probs[n:] = attention_probs_4d.reshape(
                    attention_probs_4d.shape[0], dim * dim, attention_probs_4d.shape[-1]
                )
            self._cross_attention_map = attention_probs.view(
                attention_probs.shape[0], dim, dim, attention_probs.shape[-1]
            )[n:]
        elif (
            attention_probs.shape[-1] == self.num_frames
            and (attention_probs.shape[0] == 65536)  # e.g. 2 * 8 heads * 64 * 64
        ):
            dim = int(np.sqrt(attention_probs.shape[0] // (2 * attn.heads)))
            if self.use_dd_temporal:

                def temporal_doit(origin_attn):
                    # (heads * dim * dim, F, F) -> (heads, dim, dim, F, F)
                    temporal_attn = reshape_fortran(
                        origin_attn,
                        (attn.heads, dim, dim, self.num_frames, self.num_frames),
                    )
                    temporal_attn = torch.transpose(temporal_attn, 1, 2)
                    temporal_attn = self.dd_core(temporal_attn)
                    # torch.Size([8, 64, 64, 24, 24])
                    temporal_attn = torch.transpose(temporal_attn, 1, 2)
                    # back to (heads * dim * dim, F, F)
                    temporal_attn = reshape_fortran(
                        temporal_attn,
                        (attn.heads * dim * dim, self.num_frames, self.num_frames),
                    )
                    return temporal_attn

                # NOTE: So the null-text embedding for classifier-free guidance
                # doesn't really help?
                # attention_probs[n:] = temporal_doit(attention_probs[n:])
                attention_probs[:n] = temporal_doit(attention_probs[:n])

            self._cross_attention_map = reshape_fortran(
                attention_probs[:n],
                (attn.heads, dim, dim, self.num_frames, self.num_frames),
            )
            self._cross_attention_map = self._cross_attention_map.mean(dim=0)
            self._cross_attention_map = torch.transpose(self._cross_attention_map, 0, 1)

        attention_probs = torch.abs(attention_probs)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)
        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        return hidden_states

    def dd_core(self, attention_probs):
        """All DD variants implement this function."""
        raise NotImplementedError
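
# Minimal sketch of a DD variant (hypothetical subclass for illustration; the real
# variants live elsewhere in the package). dd_core receives the attention map in a
# (batch, dim, dim, tokens) or (heads, dim, dim, frames, frames) layout and must
# return a tensor of the same shape.
#   class IdentityDDProcessor(CrossAttnProcessorBase):
#       def dd_core(self, attention_probs):
#           return attention_probs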


def localized_weight_map(attention_probs_4d, token_inds, bbox_per_frame, scale=1):
    """Use a 2D Gaussian distribution to generate a weight map, returned with
    the same size as the attention argument.
    """
    dim = int(attention_probs_4d.size()[1])
    max_val = attention_probs_4d.max()
    weight_map = torch.zeros_like(attention_probs_4d).half()
    frame_size = attention_probs_4d.shape[0] // len(bbox_per_frame)

    for i in range(len(bbox_per_frame)):
        bbox_ratios = bbox_per_frame[i]
        bbox = BoundingBox(dim, bbox_ratios)

        # Generate the Gaussian distribution map patch
        x = torch.linspace(0, bbox.height, bbox.height)
        y = torch.linspace(0, bbox.width, bbox.width)
        x, y = torch.meshgrid(x, y, indexing="ij")
        noise_patch = (
            gaussian_2d(
                x,
                y,
                mx=int(bbox.height / 2),
                my=int(bbox.width / 2),
                sx=float(bbox.height / KERNEL_DIVISION),
                sy=float(bbox.width / KERNEL_DIVISION),
            )
            .unsqueeze(0)
            .unsqueeze(-1)
            .repeat(frame_size, 1, 1, len(token_inds))
            .to(attention_probs_4d.device)
        ).half()

        # NOTE: overrides the `scale` argument with the attention maximum.
        scale = attention_probs_4d.max() * INJECTION_SCALE
        noise_patch.mul_(scale)

        b_idx = frame_size * i
        e_idx = frame_size * (i + 1)
        bbox.sliced_tensor_in_bbox(weight_map)[b_idx:e_idx, ..., token_inds] = noise_patch
    return weight_map
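
# Hedged usage sketch (shapes and names assumed from dd_core's spatial layout above):
# attention_probs_4d is the conditional half of the cross-attention reshaped to
# (frames * heads, dim, dim, 77), bbox_per_frame holds one [left, right, top, bottom]
# ratio list per frame, and token_inds picks the subject tokens. One plausible way a
# DD variant could use the returned map is a simple additive injection:
#   weights = localized_weight_map(attention_probs_4d, token_inds=[2, 3],
#                                  bbox_per_frame=bbox_per_frame)
#   attention_probs_4d = attention_probs_4d + weights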


def localized_temporal_weight_map(attention_probs_5d, bbox_per_frame, scale=1):
    """Use a 2D Gaussian distribution to generate a weight map, returned with
    the same size as the attention argument.
    """
    dim = int(attention_probs_5d.size()[1])
    f = attention_probs_5d.shape[-1]
    max_val = attention_probs_5d.max()
    weight_map = torch.zeros_like(attention_probs_5d).half()

    def get_patch(bbox_at_frame, i, j, bbox_per_frame):
        bbox = BoundingBox(dim, bbox_at_frame)

        # Generate the Gaussian distribution map patch
        x = torch.linspace(0, bbox.height, bbox.height)
        y = torch.linspace(0, bbox.width, bbox.width)
        x, y = torch.meshgrid(x, y, indexing="ij")
        noise_patch = (
            gaussian_2d(
                x,
                y,
                mx=int(bbox.height / 2),
                my=int(bbox.width / 2),
                sx=float(bbox.height / KERNEL_DIVISION),
                sy=float(bbox.width / KERNEL_DIVISION),
            )
            .unsqueeze(0)
            .repeat(attention_probs_5d.shape[0], 1, 1)
            .to(attention_probs_5d.device)
        ).half()

        # NOTE: overrides the `scale` argument with the attention maximum.
        scale = attention_probs_5d.max() * INJECTION_SCALE
        noise_patch.mul_(scale)

        # Blend the upright patch with its inverted (non-positive) counterpart
        # according to the temporal distance between frames i and j.
        inv_noise_patch = noise_patch - noise_patch.max()
        dist = float(abs(j - i)) / len(bbox_per_frame)
        final_patch = inv_noise_patch * dist + noise_patch * (1. - dist)
        # final_patch = noise_patch * (1. - dist)
        # final_patch = inv_noise_patch * dist
        return final_patch, bbox

    for j in range(len(bbox_per_frame)):
        for i in range(len(bbox_per_frame)):
            patch_i, bbox_i = get_patch(bbox_per_frame[i], i, j, bbox_per_frame)
            patch_j, bbox_j = get_patch(bbox_per_frame[j], i, j, bbox_per_frame)
            bbox_i.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_i
            bbox_j.sliced_tensor_in_bbox(weight_map)[..., i, j] = patch_j
    return weight_map
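
# Worked example of the distance blend above (illustrative numbers): with 24 frames,
# the entry for (i=0, j=12) mixes the inverted and upright patches equally
# (dist = 12 / 24 = 0.5), while diagonal entries (i == j) keep the upright Gaussian
# only (dist = 0), so attention between temporally distant frames receives a weaker
# in-box boost than attention between nearby frames.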