|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Callable, Optional, Union
|
|
|
|
import torch
|
|
import torch.nn.functional as F
|
|
from torch import nn
|
|
|
|
from diffusers.utils.import_utils import is_xformers_available
|
|
|
|
|
|
if is_xformers_available():
|
|
import xformers
|
|
import xformers.ops
|
|
else:
|
|
xformers = None
|
|
|
|
|
|
class CrossAttention(nn.Module):
    r"""
    A cross attention layer.

    The layer owns the projection weights and the tensor-shape helpers; the attention computation itself is
    delegated to a swappable processor callable (see `set_processor`).

    Parameters:
        query_dim (`int`): The number of channels in the query.
        cross_attention_dim (`int`, *optional*):
            The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
        heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
        dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
        bias (`bool`, *optional*, defaults to False):
            Set to `True` for the query, key, and value linear layers to contain a bias parameter.
        upcast_attention (`bool`, *optional*, defaults to False):
            Cast query and key to float32 before computing the attention scores.
        upcast_softmax (`bool`, *optional*, defaults to False):
            Cast the attention scores to float32 before the softmax.
        added_kv_proj_dim (`int`, *optional*):
            If given, extra `add_k_proj` / `add_v_proj` linear layers taking this many input channels are created.
        norm_num_groups (`int`, *optional*):
            If given, a `GroupNorm` with this many groups is stored as `group_norm`; otherwise `group_norm` is `None`.
        processor (`AttnProcessor`, *optional*):
            The callable invoked by `forward`. Defaults to a `CrossAttnProcessor`.
    """

    def __init__(
        self,
        query_dim: int,
        cross_attention_dim: Optional[int] = None,
        heads: int = 8,
        dim_head: int = 64,
        dropout: float = 0.0,
        bias=False,
        upcast_attention: bool = False,
        upcast_softmax: bool = False,
        added_kv_proj_dim: Optional[int] = None,
        norm_num_groups: Optional[int] = None,
        processor: Optional["AttnProcessor"] = None,
    ):
        super().__init__()
        inner_dim = dim_head * heads
        cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
        self.upcast_attention = upcast_attention
        self.upcast_softmax = upcast_softmax

        # Scores are multiplied by 1 / sqrt(dim_head) in `get_attention_scores`.
        self.scale = dim_head**-0.5

        self.heads = heads
        # Upper bound that `set_attention_slice` accepts for `slice_size`.
        self.sliceable_head_dim = heads

        self.added_kv_proj_dim = added_kv_proj_dim

        if norm_num_groups is not None:
            # NOTE(review): the added-KV processors apply this norm to the hidden states *before* `to_q`,
            # i.e. to `query_dim` channels, so this only lines up when inner_dim == query_dim -- confirm.
            self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
        else:
            self.group_norm = None

        self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
        self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
        self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)

        if self.added_kv_proj_dim is not None:
            # NOTE(review): these project to `cross_attention_dim` whereas `to_k` / `to_v` project to
            # `inner_dim`; the added-KV processors concatenate both along the sequence axis, which needs
            # the two widths to be equal -- confirm.
            self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
            self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)

        # Output head: index 0 is the linear projection, index 1 the dropout.
        self.to_out = nn.ModuleList([])
        self.to_out.append(nn.Linear(inner_dim, query_dim))
        self.to_out.append(nn.Dropout(dropout))

        processor = processor if processor is not None else CrossAttnProcessor()
        self.set_processor(processor)

    def set_use_memory_efficient_attention_xformers(
        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
    ):
        """
        Switch between the xformers-backed processor and the default `CrossAttnProcessor`.

        Args:
            use_memory_efficient_attention_xformers: `True` installs an `XFormersCrossAttnProcessor`, `False`
                installs a plain `CrossAttnProcessor`.
            attention_op: Forwarded to `XFormersCrossAttnProcessor` when enabling.

        Raises:
            NotImplementedError: when enabling on a layer built with `added_kv_proj_dim`.
            ModuleNotFoundError: when enabling and xformers is not installed.
            ValueError: when enabling and CUDA is not available.
        """
        if not use_memory_efficient_attention_xformers:
            self.set_processor(CrossAttnProcessor())
            return

        if self.added_kv_proj_dim is not None:
            raise NotImplementedError(
                "Memory efficient attention with `xformers` is currently not supported when"
                " `self.added_kv_proj_dim` is defined."
            )
        if not is_xformers_available():
            raise ModuleNotFoundError(
                (
                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
                    " xformers"
                ),
                name="xformers",
            )
        if not torch.cuda.is_available():
            raise ValueError(
                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is"
                " only available for GPU "
            )

        # Smoke-test the kernel on tiny CUDA tensors so an unusable install fails here instead of at the
        # first forward pass. Any exception simply propagates to the caller.
        _ = xformers.ops.memory_efficient_attention(
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
            torch.randn((1, 2, 40), device="cuda"),
        )

        self.set_processor(XFormersCrossAttnProcessor(attention_op=attention_op))

    def set_attention_slice(self, slice_size):
        """
        Install the processor matching `slice_size` and whether added key/value projections are in use.

        Args:
            slice_size: Number of rows of the `(batch * heads)` axis processed per step, or `None` to
                disable slicing.

        Raises:
            ValueError: if `slice_size` is smaller than 1 or larger than `self.sliceable_head_dim`.
        """
        if slice_size is not None:
            # A zero or negative size would only fail (or silently yield zeros) later, inside the sliced
            # processors, so reject it up front.
            if slice_size < 1:
                raise ValueError(f"slice_size {slice_size} has to be a positive integer.")
            if slice_size > self.sliceable_head_dim:
                raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")

        if slice_size is not None and self.added_kv_proj_dim is not None:
            processor = SlicedAttnAddedKVProcessor(slice_size)
        elif slice_size is not None:
            processor = SlicedAttnProcessor(slice_size)
        elif self.added_kv_proj_dim is not None:
            processor = CrossAttnAddedKVProcessor()
        else:
            processor = CrossAttnProcessor()

        self.set_processor(processor)

    def set_processor(self, processor: "AttnProcessor"):
        """Store the callable that `forward` delegates to."""
        self.processor = processor

    def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
        """
        Run the currently selected processor.

        The layer itself is passed as the processor's first argument; every other argument, including any
        extra `cross_attention_kwargs`, is forwarded unchanged and the processor's result is returned.
        """
        return self.processor(
            self,
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attention_mask=attention_mask,
            **cross_attention_kwargs,
        )

    def batch_to_head_dim(self, tensor):
        """Merge heads back into channels: `(batch * heads, seq, dim)` -> `(batch, seq, dim * heads)`."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor):
        """Split channels into heads: `(batch, seq, dim)` -> `(batch * heads, seq, dim // heads)`."""
        head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
        return tensor

    def get_attention_scores(self, query, key, attention_mask=None):
        """
        Return `softmax(scale * query @ key^T + attention_mask)` over the last axis.

        `query` and `key` are 3-D batched matrices. The result is cast back to the dtype `query` had on
        entry, regardless of any upcasting done in between.
        """
        dtype = query.dtype
        if self.upcast_attention:
            query = query.float()
            key = key.float()

        # With beta=0 the first argument is ignored by `baddbmm`; the uninitialised tensor only supplies
        # the output shape, dtype and device.
        attention_scores = torch.baddbmm(
            torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
            query,
            key.transpose(-1, -2),
            beta=0,
            alpha=self.scale,
        )

        if attention_mask is not None:
            attention_scores = attention_scores + attention_mask

        if self.upcast_softmax:
            attention_scores = attention_scores.float()

        attention_probs = attention_scores.softmax(dim=-1)
        attention_probs = attention_probs.to(dtype)

        return attention_probs

    def prepare_attention_mask(self, attention_mask, target_length):
        """
        Pad and head-expand an additive attention mask.

        `None` is returned as-is. A mask whose last dimension differs from `target_length` gets
        `target_length` zeros appended on the right of that dimension and is then repeated `heads` times
        along dim 0.

        NOTE(review): a mask whose last dimension already equals `target_length` is returned untouched
        (no head expansion) -- confirm that callers rely on this.
        """
        head_size = self.heads
        if attention_mask is None:
            return attention_mask

        if attention_mask.shape[-1] != target_length:
            if attention_mask.device.type == "mps":
                # Build the zero padding by hand on MPS instead of calling `F.pad` (presumably a
                # workaround for a padding limitation on that backend -- TODO confirm). The 3-tuple shape
                # assumes a 3-D mask.
                padding_shape = (attention_mask.shape[0], attention_mask.shape[1], target_length)
                # Match the mask's dtype so concatenation does not promote a half-precision mask.
                padding = torch.zeros(padding_shape, dtype=attention_mask.dtype, device=attention_mask.device)
                attention_mask = torch.concat([attention_mask, padding], dim=2)
            else:
                attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
            attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
        return attention_mask
|
|
|
|
|
|
class CrossAttnProcessor:
    """Default processor: plain multi-head attention using the projections stored on `attn`."""

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # `hidden_states` must be 3-D; only its middle (sequence) length is needed here.
        _, query_length, _ = hidden_states.shape
        mask = attn.prepare_attention_mask(attention_mask, query_length)

        q = attn.head_to_batch_dim(attn.to_q(hidden_states))

        # Without encoder states the layer attends over its own input (self-attention).
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        k = attn.head_to_batch_dim(attn.to_k(context))
        v = attn.head_to_batch_dim(attn.to_v(context))

        weights = attn.get_attention_scores(q, k, mask)
        attended = attn.batch_to_head_dim(torch.bmm(weights, v))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](attended)
        return attn.to_out[1](projected)
|
|
|
|
|
|
class LoRALinearLayer(nn.Module):
    """
    Bias-free low-rank linear map: `in_features -> rank -> out_features`.

    The `up` weight starts at zero, so a freshly built layer outputs zeros. `down` is initialised from a
    normal distribution with standard deviation `1 / rank`.

    Raises:
        ValueError: if `rank` exceeds `min(in_features, out_features)`.
    """

    def __init__(self, in_features, out_features, rank=4):
        super().__init__()

        max_rank = min(in_features, out_features)
        if rank > max_rank:
            raise ValueError(f"LoRA rank {rank} must be less or equal than {max_rank}")

        self.down = nn.Linear(in_features, rank, bias=False)
        self.up = nn.Linear(rank, out_features, bias=False)
        self.scale = 1.0

        nn.init.normal_(self.down.weight, std=1 / rank)
        nn.init.zeros_(self.up.weight)

    def forward(self, hidden_states):
        """Apply down then up projection in the weights' dtype and return the result in the input's dtype."""
        input_dtype = hidden_states.dtype
        weight_dtype = self.down.weight.dtype

        projected = self.up(self.down(hidden_states.to(weight_dtype)))
        return projected.to(input_dtype)
|
|
|
|
|
|
class LoRACrossAttnProcessor(nn.Module):
    """
    Attention processor that adds a `LoRALinearLayer` output to each of the q/k/v/out projections.

    Args:
        hidden_size: Input and output width of the query and output adapters, and output width of the
            key/value adapters.
        cross_attention_dim: Input width of the key/value adapters. Falls back to `hidden_size` when
            `None` (or 0).
        rank: Rank of every adapter.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
        super().__init__()

        kv_in_features = cross_attention_dim or hidden_size

        # `rank` must be forwarded explicitly; otherwise every adapter silently uses the
        # `LoRALinearLayer` default regardless of what the caller asked for.
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """
        Compute attention with LoRA-adjusted projections.

        `scale` multiplies each adapter's contribution before it is added to the corresponding base
        projection; `scale=0` therefore reproduces the base projections.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query)

        # Without encoder states the layer attends over its own input.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection; both terms read the same pre-projection tensor.
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
|
|
|
|
|
|
class CrossAttnAddedKVProcessor:
    """
    Processor for layers built with `added_kv_proj_dim`.

    The input is flattened from `(batch, channels, ...)` to `(batch, tokens, channels)` and normalised with
    `attn.group_norm`. Keys/values projected from `encoder_hidden_states` (via `add_k_proj` / `add_v_proj`)
    are placed in front of the keys/values projected from the input itself. The attention output is
    reshaped back to the input's shape and added to the input (residual connection).
    `encoder_hidden_states` is required here despite its `None` default.
    """

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        skip = hidden_states
        tokens = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        _, sequence_length, _ = tokens.shape
        context = encoder_hidden_states.transpose(1, 2)

        mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        # group_norm is applied channels-first, hence the transpose round trip.
        tokens = attn.group_norm(tokens.transpose(1, 2)).transpose(1, 2)

        q = attn.head_to_batch_dim(attn.to_q(tokens))

        self_k = attn.head_to_batch_dim(attn.to_k(tokens))
        self_v = attn.head_to_batch_dim(attn.to_v(tokens))

        added_k = attn.head_to_batch_dim(attn.add_k_proj(context))
        added_v = attn.head_to_batch_dim(attn.add_v_proj(context))

        # Encoder-derived entries come first along the sequence axis.
        k = torch.concat([added_k, self_k], dim=1)
        v = torch.concat([added_v, self_v], dim=1)

        weights = attn.get_attention_scores(q, k, mask)
        out = attn.batch_to_head_dim(torch.bmm(weights, v))

        # to_out[0] is the linear projection, to_out[1] the dropout.
        out = attn.to_out[1](attn.to_out[0](out))

        out = out.transpose(-1, -2).reshape(skip.shape)
        return out + skip
|
|
|
|
|
|
class XFormersCrossAttnProcessor:
    """
    Processor that computes attention with `xformers.ops.memory_efficient_attention`.

    Args:
        attention_op: Passed as `op=` to `memory_efficient_attention` on every call.
    """

    def __init__(self, attention_op: Optional[Callable] = None):
        self.attention_op = attention_op

    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # `hidden_states` must be 3-D; only its middle (sequence) length is needed here.
        _, query_length, _ = hidden_states.shape

        attn_bias = attn.prepare_attention_mask(attention_mask, query_length)

        # Without encoder states the layer attends over its own input.
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        q = attn.to_q(hidden_states)
        k = attn.to_k(context)
        v = attn.to_v(context)

        # Heads are folded into the batch axis and the tensors are made contiguous before the kernel call.
        q = attn.head_to_batch_dim(q).contiguous()
        k = attn.head_to_batch_dim(k).contiguous()
        v = attn.head_to_batch_dim(v).contiguous()

        out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self.attention_op)
        out = out.to(q.dtype)
        out = attn.batch_to_head_dim(out)

        # to_out[0] is the linear projection, to_out[1] the dropout.
        projected = attn.to_out[0](out)
        return attn.to_out[1](projected)
|
|
|
|
|
|
class LoRAXFormersCrossAttnProcessor(nn.Module):
    """
    LoRA attention processor backed by `xformers.ops.memory_efficient_attention`.

    Args:
        hidden_size: Input and output width of the query and output adapters, and output width of the
            key/value adapters.
        cross_attention_dim: Input width of the key/value adapters. Falls back to `hidden_size` when
            `None` (or 0).
        rank: Rank of every adapter.
        attention_op: Optional operator passed as `op=` to `memory_efficient_attention`; `None` keeps
            the library's default selection.
    """

    def __init__(self, hidden_size, cross_attention_dim, rank=4, attention_op: Optional[Callable] = None):
        super().__init__()

        self.attention_op = attention_op
        kv_in_features = cross_attention_dim or hidden_size

        # `rank` must be forwarded explicitly; otherwise every adapter silently uses the
        # `LoRALinearLayer` default regardless of what the caller asked for.
        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size, rank)
        self.to_k_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_v_lora = LoRALinearLayer(kv_in_features, hidden_size, rank)
        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size, rank)

    def __call__(
        self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0
    ):
        """
        Compute attention with LoRA-adjusted projections.

        `scale` multiplies each adapter's contribution before it is added to the corresponding base
        projection.
        """
        batch_size, sequence_length, _ = hidden_states.shape
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
        query = attn.head_to_batch_dim(query).contiguous()

        # Without encoder states the layer attends over its own input.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)

        key = attn.head_to_batch_dim(key).contiguous()
        value = attn.head_to_batch_dim(value).contiguous()

        hidden_states = xformers.ops.memory_efficient_attention(
            query, key, value, attn_bias=attention_mask, op=self.attention_op
        )
        # The kernel output still has heads folded into the batch axis. Merge them back into the channel
        # axis so the output projection (built for `inner_dim` inputs) receives the width it expects.
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection; both terms read the same pre-projection tensor.
        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
|
|
|
|
|
|
class SlicedAttnProcessor:
    """
    Attention processor that walks the `(batch * heads)` axis in chunks of `slice_size` rows, so only one
    chunk's attention matrix is materialised at a time.

    Args:
        slice_size: Number of `(batch * heads)` rows handled per step. It does not need to divide the
            total evenly; the final chunk is simply shorter.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        # Without encoder states the layer attends over its own input.
        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        batch_size_attention = query.shape[0]
        # Output buffer, filled chunk by chunk below.
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Step by start index rather than by `n // slice_size` whole chunks: the latter skips the trailing
        # rows (leaving them zero) whenever `slice_size` does not divide `batch_size_attention`.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size  # slicing clamps this to the tensor length

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        return hidden_states
|
|
|
|
|
|
class SlicedAttnAddedKVProcessor:
    """
    Sliced attention for layers built with `added_kv_proj_dim`.

    The input is flattened from `(batch, channels, ...)` to `(batch, tokens, channels)` and normalised with
    `attn.group_norm`. Keys/values projected from `encoder_hidden_states` are placed in front of those
    projected from the input itself. Attention is then computed over the `(batch * heads)` axis in chunks
    of `slice_size` rows, and the result is reshaped to the input's shape and added to the input.
    `encoder_hidden_states` is required here despite its `None` default.

    Args:
        slice_size: Number of `(batch * heads)` rows handled per step. It does not need to divide the
            total evenly; the final chunk is simply shorter.
    """

    def __init__(self, slice_size):
        self.slice_size = slice_size

    def __call__(self, attn: "CrossAttention", hidden_states, encoder_hidden_states=None, attention_mask=None):
        residual = hidden_states
        hidden_states = hidden_states.view(hidden_states.shape[0], hidden_states.shape[1], -1).transpose(1, 2)
        encoder_hidden_states = encoder_hidden_states.transpose(1, 2)

        batch_size, sequence_length, _ = hidden_states.shape

        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

        # group_norm is applied channels-first, hence the transpose round trip.
        hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)
        dim = query.shape[-1]
        query = attn.head_to_batch_dim(query)

        key = attn.to_k(hidden_states)
        value = attn.to_v(hidden_states)
        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)

        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        encoder_hidden_states_key_proj = attn.head_to_batch_dim(encoder_hidden_states_key_proj)
        encoder_hidden_states_value_proj = attn.head_to_batch_dim(encoder_hidden_states_value_proj)

        # Encoder-derived entries come first along the sequence axis.
        key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
        value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)

        batch_size_attention = query.shape[0]
        # Output buffer, filled chunk by chunk below.
        hidden_states = torch.zeros(
            (batch_size_attention, sequence_length, dim // attn.heads), device=query.device, dtype=query.dtype
        )

        # Step by start index rather than by `n // slice_size` whole chunks: the latter skips the trailing
        # rows (leaving them zero) whenever `slice_size` does not divide `batch_size_attention`.
        for start_idx in range(0, batch_size_attention, self.slice_size):
            end_idx = start_idx + self.slice_size  # slicing clamps this to the tensor length

            query_slice = query[start_idx:end_idx]
            key_slice = key[start_idx:end_idx]
            attn_mask_slice = attention_mask[start_idx:end_idx] if attention_mask is not None else None

            attn_slice = attn.get_attention_scores(query_slice, key_slice, attn_mask_slice)
            attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])

            hidden_states[start_idx:end_idx] = attn_slice

        hidden_states = attn.batch_to_head_dim(hidden_states)

        # Linear projection.
        hidden_states = attn.to_out[0](hidden_states)
        # Dropout.
        hidden_states = attn.to_out[1](hidden_states)

        # Back to the input layout, then the residual connection.
        hidden_states = hidden_states.transpose(-1, -2).reshape(residual.shape)
        hidden_states = hidden_states + residual

        return hidden_states
|
|
|
|
|
|
# Type alias for every processor class defined in this module that `CrossAttention.set_processor`
# may receive, including the LoRA variants.
AttnProcessor = Union[
    CrossAttnProcessor,
    XFormersCrossAttnProcessor,
    SlicedAttnProcessor,
    CrossAttnAddedKVProcessor,
    SlicedAttnAddedKVProcessor,
    LoRACrossAttnProcessor,
    LoRAXFormersCrossAttnProcessor,
]
|
|
|