"""该模型是自定义的attn_processor,实现特殊功能的 Attn功能。 |
|
相对而言,开源代码经常会重新定义Attention 类, |
|
|
|
This module implements special AttnProcessor function with custom attn_processor class. |
|
While other open source code always modify Attention class. |
|
""" |
|
|
|
from __future__ import annotations |
|
|
|
import time |
|
from typing import Any, Callable, Optional |
|
import logging |
|
|
|
from einops import rearrange, repeat |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import xformers |
|
from diffusers.models.lora import LoRACompatibleLinear |
|
|
|
from diffusers.utils.torch_utils import maybe_allow_in_graph |
|
from diffusers.models.attention_processor import ( |
|
Attention as DiffusersAttention, |
|
AttnProcessor, |
|
AttnProcessor2_0, |
|
) |
|
from ..data.data_util import ( |
|
batch_concat_two_tensor_with_index, |
|
batch_index_select, |
|
align_repeat_tensor_single_dim, |
|
batch_adain_conditioned_tensor, |
|
) |
|
|
|
from . import Model_Register |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
@maybe_allow_in_graph |
|
class IPAttention(DiffusersAttention): |
|
r""" |
|
Modified Attention class which has special layer, like ip_apadapter_to_k, ip_apadapter_to_v, |
|
""" |
|
|
|
def __init__( |
|
self, |
|
query_dim: int, |
|
cross_attention_dim: int | None = None, |
|
heads: int = 8, |
|
dim_head: int = 64, |
|
dropout: float = 0, |
|
bias=False, |
|
upcast_attention: bool = False, |
|
upcast_softmax: bool = False, |
|
cross_attention_norm: str | None = None, |
|
cross_attention_norm_num_groups: int = 32, |
|
added_kv_proj_dim: int | None = None, |
|
norm_num_groups: int | None = None, |
|
spatial_norm_dim: int | None = None, |
|
out_bias: bool = True, |
|
scale_qk: bool = True, |
|
only_cross_attention: bool = False, |
|
eps: float = 0.00001, |
|
rescale_output_factor: float = 1, |
|
residual_connection: bool = False, |
|
_from_deprecated_attn_block=False, |
|
processor: AttnProcessor | None = None, |
|
cross_attn_temporal_cond: bool = False, |
|
image_scale: float = 1.0, |
|
ip_adapter_dim: int = None, |
|
need_t2i_facein: bool = False, |
|
facein_dim: int = None, |
|
need_t2i_ip_adapter_face: bool = False, |
|
ip_adapter_face_dim: int = None, |
|
): |
|
super().__init__( |
|
query_dim, |
|
cross_attention_dim, |
|
heads, |
|
dim_head, |
|
dropout, |
|
bias, |
|
upcast_attention, |
|
upcast_softmax, |
|
cross_attention_norm, |
|
cross_attention_norm_num_groups, |
|
added_kv_proj_dim, |
|
norm_num_groups, |
|
spatial_norm_dim, |
|
out_bias, |
|
scale_qk, |
|
only_cross_attention, |
|
eps, |
|
rescale_output_factor, |
|
residual_connection, |
|
_from_deprecated_attn_block, |
|
processor, |
|
) |
|
self.cross_attn_temporal_cond = cross_attn_temporal_cond |
|
self.image_scale = image_scale |
|
|
|
|
|
if cross_attn_temporal_cond: |
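            # Extra projections that map IP-Adapter image embeddings (ip_adapter_dim)
            # into this attention's key/value space (query_dim).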
|
self.to_k_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False) |
|
self.to_v_ip = LoRACompatibleLinear(ip_adapter_dim, query_dim, bias=False) |
|
|
|
self.need_t2i_facein = need_t2i_facein |
|
self.facein_dim = facein_dim |
|
if need_t2i_facein: |
|
raise NotImplementedError("facein") |
|
|
|
|
|
self.need_t2i_ip_adapter_face = need_t2i_ip_adapter_face |
|
self.ip_adapter_face_dim = ip_adapter_face_dim |
|
if need_t2i_ip_adapter_face: |
|
self.ip_adapter_face_to_k_ip = LoRACompatibleLinear( |
|
ip_adapter_face_dim, query_dim, bias=False |
|
) |
|
self.ip_adapter_face_to_v_ip = LoRACompatibleLinear( |
|
ip_adapter_face_dim, query_dim, bias=False |
|
) |
|
|
|
def set_use_memory_efficient_attention_xformers( |
|
self, |
|
use_memory_efficient_attention_xformers: bool, |
|
attention_op: Callable[..., Any] | None = None, |
|
): |
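        # Processors whose class name contains "XFormers" or "IP" already run
        # xformers memory-efficient attention themselves, so do not wrap them again.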
|
if ( |
|
"XFormers" in self.processor.__class__.__name__ |
|
or "IP" in self.processor.__class__.__name__ |
|
): |
|
pass |
|
else: |
|
return super().set_use_memory_efficient_attention_xformers( |
|
use_memory_efficient_attention_xformers, attention_op |
|
) |
|
|
|
|
|
@Model_Register.register |
|
class BaseIPAttnProcessor(nn.Module): |
|
print_idx = 0 |
|
|
|
def __init__(self, *args, **kwargs) -> None: |
|
super().__init__(*args, **kwargs) |
|
|
|
|
|
@Model_Register.register |
|
class T2IReferencenetIPAdapterXFormersAttnProcessor(BaseIPAttnProcessor): |
|
r""" |
|
面向 ref_image的 self_attn的 IPAdapter |
|
""" |
|
print_idx = 0 |
|
|
|
def __init__( |
|
self, |
|
attention_op: Optional[Callable] = None, |
|
): |
|
super().__init__() |
|
|
|
self.attention_op = attention_op |
|
|
|
def __call__( |
|
self, |
|
attn: IPAttention, |
|
hidden_states: torch.FloatTensor, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
temb: Optional[torch.FloatTensor] = None, |
|
scale: float = 1.0, |
|
num_frames: int = None, |
|
sample_index: torch.LongTensor = None, |
|
vision_conditon_frames_sample_index: torch.LongTensor = None, |
|
refer_emb: torch.Tensor = None, |
|
vision_clip_emb: torch.Tensor = None, |
|
ip_adapter_scale: float = 1.0, |
|
face_emb: torch.Tensor = None, |
|
facein_scale: float = 1.0, |
|
ip_adapter_face_emb: torch.Tensor = None, |
|
ip_adapter_face_scale: float = 1.0, |
|
do_classifier_free_guidance: bool = False, |
|
): |
|
residual = hidden_states |
|
|
|
if attn.spatial_norm is not None: |
|
hidden_states = attn.spatial_norm(hidden_states, temb) |
|
|
|
input_ndim = hidden_states.ndim |
|
|
|
if input_ndim == 4: |
|
batch_size, channel, height, width = hidden_states.shape |
|
hidden_states = hidden_states.view( |
|
batch_size, channel, height * width |
|
).transpose(1, 2) |
|
|
|
batch_size, key_tokens, _ = ( |
|
hidden_states.shape |
|
if encoder_hidden_states is None |
|
else encoder_hidden_states.shape |
|
) |
|
|
|
attention_mask = attn.prepare_attention_mask( |
|
attention_mask, key_tokens, batch_size |
|
) |
|
if attention_mask is not None: |
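            # xformers does not broadcast the singleton query dimension of the mask,
            # so expand it here from (batch*heads, 1, key_tokens) to
            # (batch*heads, query_tokens, key_tokens).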
|
|
|
|
|
|
|
|
|
|
|
|
|
_, query_tokens, _ = hidden_states.shape |
|
attention_mask = attention_mask.expand(-1, query_tokens, -1) |
|
|
|
if attn.group_norm is not None: |
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( |
|
1, 2 |
|
) |
|
|
|
query = attn.to_q(hidden_states, scale=scale) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states = hidden_states |
|
elif attn.norm_cross: |
|
encoder_hidden_states = attn.norm_encoder_hidden_states( |
|
encoder_hidden_states |
|
) |
|
encoder_hidden_states = align_repeat_tensor_single_dim( |
|
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0 |
|
) |
|
key = attn.to_k(encoder_hidden_states, scale=scale) |
|
value = attn.to_v(encoder_hidden_states, scale=scale) |
|
|
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(face_emb)={type(face_emb)}, facein_scale={facein_scale}" |
|
) |
|
if facein_scale > 0 and face_emb is not None: |
|
raise NotImplementedError("facein") |
|
|
|
query = attn.head_to_batch_dim(query).contiguous() |
|
key = attn.head_to_batch_dim(key).contiguous() |
|
value = attn.head_to_batch_dim(value).contiguous() |
|
hidden_states = xformers.ops.memory_efficient_attention( |
|
query, |
|
key, |
|
value, |
|
attn_bias=attention_mask, |
|
op=self.attention_op, |
|
scale=attn.scale, |
|
) |
|
|
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(vision_clip_emb)={type(vision_clip_emb)}" |
|
) |
|
if ip_adapter_scale > 0 and vision_clip_emb is not None: |
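            # IP-Adapter branch: project the image embeddings to key/value, attend to
            # them with the same query, and add the result scaled by ip_adapter_scale.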
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"T2I cross_attn, ipadapter, vision_clip_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}" |
|
) |
|
ip_key = attn.to_k_ip(vision_clip_emb) |
|
ip_value = attn.to_v_ip(vision_clip_emb) |
|
ip_key = align_repeat_tensor_single_dim( |
|
ip_key, target_length=batch_size, dim=0 |
|
) |
|
ip_value = align_repeat_tensor_single_dim( |
|
ip_value, target_length=batch_size, dim=0 |
|
) |
|
ip_key = attn.head_to_batch_dim(ip_key).contiguous() |
|
ip_value = attn.head_to_batch_dim(ip_value).contiguous() |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}" |
|
) |
|
|
|
hidden_states_from_ip = xformers.ops.memory_efficient_attention( |
|
query, |
|
ip_key, |
|
ip_value, |
|
attn_bias=attention_mask, |
|
op=self.attention_op, |
|
scale=attn.scale, |
|
) |
|
hidden_states = hidden_states + ip_adapter_scale * hidden_states_from_ip |
|
|
|
|
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"T2IReferencenetIPAdapterXFormersAttnProcessor,type(ip_adapter_face_emb)={type(ip_adapter_face_emb)}" |
|
) |
|
if ip_adapter_face_scale > 0 and ip_adapter_face_emb is not None: |
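            # IP-Adapter-Face branch: same pattern as above, using the face-embedding
            # projections and ip_adapter_face_scale.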
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"T2I cross_attn, ipadapter face, ip_adapter_face_emb={vision_clip_emb.shape}, hidden_states={hidden_states.shape}, batch_size={batch_size}" |
|
) |
|
ip_key = attn.ip_adapter_face_to_k_ip(ip_adapter_face_emb) |
|
ip_value = attn.ip_adapter_face_to_v_ip(ip_adapter_face_emb) |
|
ip_key = align_repeat_tensor_single_dim( |
|
ip_key, target_length=batch_size, dim=0 |
|
) |
|
ip_value = align_repeat_tensor_single_dim( |
|
ip_value, target_length=batch_size, dim=0 |
|
) |
|
ip_key = attn.head_to_batch_dim(ip_key).contiguous() |
|
ip_value = attn.head_to_batch_dim(ip_value).contiguous() |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"query={query.shape}, ip_key={ip_key.shape}, ip_value={ip_value.shape}" |
|
) |
|
|
|
hidden_states_from_ip = xformers.ops.memory_efficient_attention( |
|
query, |
|
ip_key, |
|
ip_value, |
|
attn_bias=attention_mask, |
|
op=self.attention_op, |
|
scale=attn.scale, |
|
) |
|
hidden_states = ( |
|
hidden_states + ip_adapter_face_scale * hidden_states_from_ip |
|
) |
|
|
|
|
|
hidden_states = hidden_states.to(query.dtype) |
|
hidden_states = attn.batch_to_head_dim(hidden_states) |
|
|
|
|
|
hidden_states = attn.to_out[0](hidden_states, scale=scale) |
|
|
|
hidden_states = attn.to_out[1](hidden_states) |
|
|
|
if input_ndim == 4: |
|
hidden_states = hidden_states.transpose(-1, -2).reshape( |
|
batch_size, channel, height, width |
|
) |
|
|
|
if attn.residual_connection: |
|
hidden_states = hidden_states + residual |
|
|
|
hidden_states = hidden_states / attn.rescale_output_factor |
|
self.print_idx += 1 |
|
return hidden_states |
|
|
|
|
|
@Model_Register.register |
|
class NonParamT2ISelfReferenceXFormersAttnProcessor(BaseIPAttnProcessor): |
|
r""" |
|
面向首帧的 referenceonly attn,适用于 T2I的 self_attn |
|
referenceonly with vis_cond as key, value, in t2i self_attn. |
|
""" |
|
print_idx = 0 |
|
|
|
def __init__( |
|
self, |
|
attention_op: Optional[Callable] = None, |
|
): |
|
super().__init__() |
|
|
|
self.attention_op = attention_op |
|
|
|
def __call__( |
|
self, |
|
attn: IPAttention, |
|
hidden_states: torch.FloatTensor, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
temb: Optional[torch.FloatTensor] = None, |
|
scale: float = 1.0, |
|
num_frames: int = None, |
|
sample_index: torch.LongTensor = None, |
|
vision_conditon_frames_sample_index: torch.LongTensor = None, |
|
refer_emb: torch.Tensor = None, |
|
face_emb: torch.Tensor = None, |
|
vision_clip_emb: torch.Tensor = None, |
|
ip_adapter_scale: float = 1.0, |
|
facein_scale: float = 1.0, |
|
ip_adapter_face_emb: torch.Tensor = None, |
|
ip_adapter_face_scale: float = 1.0, |
|
do_classifier_free_guidance: bool = False, |
|
): |
|
residual = hidden_states |
|
|
|
if attn.spatial_norm is not None: |
|
hidden_states = attn.spatial_norm(hidden_states, temb) |
|
|
|
input_ndim = hidden_states.ndim |
|
|
|
if input_ndim == 4: |
|
batch_size, channel, height, width = hidden_states.shape |
|
hidden_states = hidden_states.view( |
|
batch_size, channel, height * width |
|
).transpose(1, 2) |
|
|
|
batch_size, key_tokens, _ = ( |
|
hidden_states.shape |
|
if encoder_hidden_states is None |
|
else encoder_hidden_states.shape |
|
) |
|
|
|
attention_mask = attn.prepare_attention_mask( |
|
attention_mask, key_tokens, batch_size |
|
) |
|
if attention_mask is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
_, query_tokens, _ = hidden_states.shape |
|
attention_mask = attention_mask.expand(-1, query_tokens, -1) |
|
|
|
|
|
if ( |
|
vision_conditon_frames_sample_index is not None and num_frames > 1 |
|
) or refer_emb is not None: |
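            # Build reference tokens: reshape the per-frame latents to (b, t, hw, c),
            # gather the vision-condition frames and/or the referencenet emb, and
            # append them along the token dimension so every frame attends to them.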
|
batchsize_timesize = hidden_states.shape[0] |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor 0, hidden_states={hidden_states.shape}, vision_conditon_frames_sample_index={vision_conditon_frames_sample_index}" |
|
) |
|
encoder_hidden_states = rearrange( |
|
hidden_states, "(b t) hw c -> b t hw c", t=num_frames |
|
) |
|
|
|
if vision_conditon_frames_sample_index is not None and num_frames > 1: |
|
ip_hidden_states = batch_index_select( |
|
encoder_hidden_states, |
|
dim=1, |
|
index=vision_conditon_frames_sample_index, |
|
).contiguous() |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor 1, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" |
|
) |
|
|
|
ip_hidden_states = rearrange( |
|
ip_hidden_states, "b t hw c -> b 1 (t hw) c" |
|
) |
|
ip_hidden_states = align_repeat_tensor_single_dim( |
|
ip_hidden_states, |
|
dim=1, |
|
target_length=num_frames, |
|
) |
|
|
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor 2, vis_cond referenceonly, encoder_hidden_states={encoder_hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" |
|
) |
|
encoder_hidden_states = torch.concat( |
|
[encoder_hidden_states, ip_hidden_states], dim=2 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor 3, hidden_states={hidden_states.shape}, ip_hidden_states={ip_hidden_states.shape}" |
|
) |
|
|
|
if refer_emb is not None: |
|
refer_emb = rearrange(refer_emb, "b c t h w->b 1 (t h w) c") |
|
refer_emb = align_repeat_tensor_single_dim( |
|
refer_emb, target_length=num_frames, dim=1 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor4, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}" |
|
) |
|
encoder_hidden_states = torch.concat( |
|
[encoder_hidden_states, refer_emb], dim=2 |
|
) |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"NonParamT2ISelfReferenceXFormersAttnProcessor5, referencenet, encoder_hidden_states={encoder_hidden_states.shape}, refer_emb={refer_emb.shape}" |
|
) |
|
encoder_hidden_states = rearrange( |
|
encoder_hidden_states, "b t hw c -> (b t) hw c" |
|
) |
|
|
|
|
|
if attn.group_norm is not None: |
|
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( |
|
1, 2 |
|
) |
|
|
|
query = attn.to_q(hidden_states, scale=scale) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states = hidden_states |
|
elif attn.norm_cross: |
|
encoder_hidden_states = attn.norm_encoder_hidden_states( |
|
encoder_hidden_states |
|
) |
|
encoder_hidden_states = align_repeat_tensor_single_dim( |
|
encoder_hidden_states, target_length=hidden_states.shape[0], dim=0 |
|
) |
|
key = attn.to_k(encoder_hidden_states, scale=scale) |
|
value = attn.to_v(encoder_hidden_states, scale=scale) |
|
|
|
query = attn.head_to_batch_dim(query).contiguous() |
|
key = attn.head_to_batch_dim(key).contiguous() |
|
value = attn.head_to_batch_dim(value).contiguous() |
|
|
|
hidden_states = xformers.ops.memory_efficient_attention( |
|
query, |
|
key, |
|
value, |
|
attn_bias=attention_mask, |
|
op=self.attention_op, |
|
scale=attn.scale, |
|
) |
|
hidden_states = hidden_states.to(query.dtype) |
|
hidden_states = attn.batch_to_head_dim(hidden_states) |
|
|
|
|
|
hidden_states = attn.to_out[0](hidden_states, scale=scale) |
|
|
|
hidden_states = attn.to_out[1](hidden_states) |
|
|
|
if input_ndim == 4: |
|
hidden_states = hidden_states.transpose(-1, -2).reshape( |
|
batch_size, channel, height, width |
|
) |
|
|
|
if attn.residual_connection: |
|
hidden_states = hidden_states + residual |
|
|
|
hidden_states = hidden_states / attn.rescale_output_factor |
|
self.print_idx += 1 |
|
|
|
return hidden_states |
|
|
|
|
|
@Model_Register.register |
|
class NonParamReferenceIPXFormersAttnProcessor( |
|
NonParamT2ISelfReferenceXFormersAttnProcessor |
|
): |
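    """Behaves the same as NonParamT2ISelfReferenceXFormersAttnProcessor; registered under a separate name."""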
|
def __init__(self, attention_op: Callable[..., Any] | None = None): |
|
super().__init__(attention_op) |
|
|
|
|
|
@maybe_allow_in_graph |
|
class ReferEmbFuseAttention(IPAttention): |
|
"""使用 attention 融合 refernet 中的 emb 到 unet 对应的 latens 中 |
|
# TODO: 目前只支持 bt hw c 的融合,后续考虑增加对 视频 bhw t c、b thw c的融合 |
|
residual_connection: bool = True, 默认, 从不产生影响开始学习 |
|
|
|
use attention to fuse referencenet emb into unet latents |
|
# TODO: by now, only support bt hw c, later consider to support bhw t c, b thw c |
|
residual_connection: bool = True, default, start from no effect |
|
|
|
Args: |
|
IPAttention (_type_): _description_ |
|
""" |
|
|
|
print_idx = 0 |
|
|
|
def __init__( |
|
self, |
|
query_dim: int, |
|
cross_attention_dim: int | None = None, |
|
heads: int = 8, |
|
dim_head: int = 64, |
|
dropout: float = 0, |
|
bias=False, |
|
upcast_attention: bool = False, |
|
upcast_softmax: bool = False, |
|
cross_attention_norm: str | None = None, |
|
cross_attention_norm_num_groups: int = 32, |
|
added_kv_proj_dim: int | None = None, |
|
norm_num_groups: int | None = None, |
|
spatial_norm_dim: int | None = None, |
|
out_bias: bool = True, |
|
scale_qk: bool = True, |
|
only_cross_attention: bool = False, |
|
eps: float = 0.00001, |
|
rescale_output_factor: float = 1, |
|
residual_connection: bool = True, |
|
_from_deprecated_attn_block=False, |
|
processor: AttnProcessor | None = None, |
|
cross_attn_temporal_cond: bool = False, |
|
image_scale: float = 1, |
|
): |
|
super().__init__( |
|
query_dim, |
|
cross_attention_dim, |
|
heads, |
|
dim_head, |
|
dropout, |
|
bias, |
|
upcast_attention, |
|
upcast_softmax, |
|
cross_attention_norm, |
|
cross_attention_norm_num_groups, |
|
added_kv_proj_dim, |
|
norm_num_groups, |
|
spatial_norm_dim, |
|
out_bias, |
|
scale_qk, |
|
only_cross_attention, |
|
eps, |
|
rescale_output_factor, |
|
residual_connection, |
|
_from_deprecated_attn_block, |
|
processor, |
|
cross_attn_temporal_cond, |
|
image_scale, |
|
) |
|
self.processor = None |
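        # This class implements its own forward, so no AttnProcessor is needed.
        # Zero-init the output projection so the attention branch contributes nothing
        # at initialization; with residual_connection=True the module starts out as an
        # identity mapping and learns the fusion during training.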
|
|
|
nn.init.zeros_(self.to_out[0].weight) |
|
nn.init.zeros_(self.to_out[0].bias) |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.FloatTensor, |
|
encoder_hidden_states: Optional[torch.FloatTensor] = None, |
|
attention_mask: Optional[torch.FloatTensor] = None, |
|
temb: Optional[torch.FloatTensor] = None, |
|
scale: float = 1.0, |
|
num_frames: int = None, |
|
) -> torch.Tensor: |
|
"""fuse referencenet emb b c t2 h2 w2 into unet latents b c t1 h1 w1 with attn |
|
refer to musev/models/attention_processor.py::NonParamT2ISelfReferenceXFormersAttnProcessor |
|
|
|
Args: |
|
hidden_states (torch.FloatTensor): unet latents, (b t1) c h1 w1 |
|
encoder_hidden_states (Optional[torch.FloatTensor], optional): referencenet emb b c2 t2 h2 w2. Defaults to None. |
|
attention_mask (Optional[torch.FloatTensor], optional): _description_. Defaults to None. |
|
temb (Optional[torch.FloatTensor], optional): _description_. Defaults to None. |
|
scale (float, optional): _description_. Defaults to 1.0. |
|
num_frames (int, optional): _description_. Defaults to None. |
|
|
|
Returns: |
|
torch.Tensor: _description_ |
|
""" |
|
residual = hidden_states |
|
|
|
hidden_states = rearrange( |
|
hidden_states, "(b t) c h w -> b c t h w", t=num_frames |
|
) |
|
batch_size, channel, t1, height, width = hidden_states.shape |
|
if self.print_idx == 0: |
|
logger.debug( |
|
f"hidden_states={hidden_states.shape},encoder_hidden_states={encoder_hidden_states.shape}" |
|
) |
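        # Key/value source = [referencenet tokens, latent tokens]: flatten the reference
        # emb to tokens, repeat it for each of the t1 frames, and concatenate with the
        # per-frame latent tokens.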
|
|
|
encoder_hidden_states = rearrange( |
|
encoder_hidden_states, " b c t2 h w-> b (t2 h w) c" |
|
) |
|
encoder_hidden_states = repeat( |
|
encoder_hidden_states, " b t2hw c -> (b t) t2hw c", t=t1 |
|
) |
|
hidden_states = rearrange(hidden_states, " b c t h w-> (b t) (h w) c") |
|
|
|
encoder_hidden_states = torch.concat( |
|
[encoder_hidden_states, hidden_states], dim=1 |
|
) |
|
|
|
|
|
|
|
|
|
|
|
if self.spatial_norm is not None: |
|
hidden_states = self.spatial_norm(hidden_states, temb) |
|
|
|
_, key_tokens, _ = ( |
|
hidden_states.shape |
|
if encoder_hidden_states is None |
|
else encoder_hidden_states.shape |
|
) |
|
|
|
attention_mask = self.prepare_attention_mask( |
|
attention_mask, key_tokens, batch_size |
|
) |
|
if attention_mask is not None: |
|
|
|
|
|
|
|
|
|
|
|
|
|
_, query_tokens, _ = hidden_states.shape |
|
attention_mask = attention_mask.expand(-1, query_tokens, -1) |
|
|
|
if self.group_norm is not None: |
|
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose( |
|
1, 2 |
|
) |
|
|
|
query = self.to_q(hidden_states, scale=scale) |
|
|
|
if encoder_hidden_states is None: |
|
encoder_hidden_states = hidden_states |
|
elif self.norm_cross: |
|
encoder_hidden_states = self.norm_encoder_hidden_states( |
|
encoder_hidden_states |
|
) |
|
|
|
key = self.to_k(encoder_hidden_states, scale=scale) |
|
value = self.to_v(encoder_hidden_states, scale=scale) |
|
|
|
query = self.head_to_batch_dim(query).contiguous() |
|
key = self.head_to_batch_dim(key).contiguous() |
|
value = self.head_to_batch_dim(value).contiguous() |
|
|
|
|
|
|
|
hidden_states = xformers.ops.memory_efficient_attention( |
|
query, |
|
key, |
|
value, |
|
attn_bias=attention_mask, |
|
scale=self.scale, |
|
) |
|
hidden_states = hidden_states.to(query.dtype) |
|
hidden_states = self.batch_to_head_dim(hidden_states) |
|
|
|
|
|
hidden_states = self.to_out[0](hidden_states, scale=scale) |
|
|
|
hidden_states = self.to_out[1](hidden_states) |
|
|
|
hidden_states = rearrange( |
|
hidden_states, |
|
"bt (h w) c-> bt c h w", |
|
h=height, |
|
w=width, |
|
) |
|
if self.residual_connection: |
|
hidden_states = hidden_states + residual |
|
|
|
hidden_states = hidden_states / self.rescale_output_factor |
|
self.print_idx += 1 |
|
return hidden_states |
|
|