Spaces:
Running
on
Zero
Running
on
Zero
# modified from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
# Module-level state shared by every ConceptrolAttnProcessor in this process.
# One slot per image concept: when global masking is enabled, the concept-mask
# layer stores its spatial masks here and the other layers read them back.
global_concept_mask: list = []
# Visualization logs keyed by processor name; each value is a list of 1-D
# numpy arrays appended during forward calls.
attn_mask_logs: dict = {}
text_attn_map_logs: dict = {}
image_attn_map_logs: dict = {}
class AttnProcessor(nn.Module):
    r"""
    Default attention processor (eager matmul + softmax, no fused kernel).

    ``hidden_size`` and ``cross_attention_dim`` are accepted only so this class
    can be constructed with the same arguments as the IP-Adapter processors;
    they are ignored.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Self-attention is used when ``encoder_hidden_states`` is None. A 4-D
        ``(batch, channel, height, width)`` input is flattened into a token
        sequence for the attention and reshaped back to 4-D afterwards.
        """
        residual = hidden_states
        x = (
            hidden_states
            if attn.spatial_norm is None
            else attn.spatial_norm(hidden_states, temb)
        )

        is_spatial = x.ndim == 4
        if is_spatial:
            _, channel, height, width = x.shape
            x = x.view(x.shape[0], channel, height * width).transpose(1, 2)

        context_shape = (
            x.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        batch_size, sequence_length, _ = context_shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            x = attn.group_norm(x.transpose(1, 2)).transpose(1, 2)

        # Keys/values come from the context; queries always come from `x`.
        if encoder_hidden_states is None:
            context = x
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        query = attn.head_to_batch_dim(attn.to_q(x))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(linear_proj(out))

        if is_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            out = out + residual
        return out / attn.rescale_output_factor
class IPAttnProcessor(nn.Module):
    r"""
    Attention processor for IP-Adapter.

    For cross-attention, the last ``num_tokens`` entries of
    ``encoder_hidden_states`` are treated as image-prompt tokens. They go
    through the adapter's own key/value projections (``to_k_ip`` /
    ``to_v_ip``), and the resulting attention output is added, weighted by
    ``scale``, to the ordinary text cross-attention output. After such a call,
    ``self.attn_map`` holds the image-token attention probabilities.

    When ``encoder_hidden_states`` is None (self-attention) there are no image
    tokens, so only the ordinary attention is computed.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Text (or self-) attention plus ``scale`` times image-prompt attention."""
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states)

        # Stays None for self-attention: there are no image-prompt tokens then.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the context into text tokens and trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states is not None:
            # IP-Adapter branch: same queries, adapter-specific keys/values.
            ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
            ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))
            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            # Kept for visualization by callers.
            self.attn_map = ip_attention_probs
            ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
            ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
class ConceptrolAttnProcessor(nn.Module):
    r"""
    Attention processor for Conceptrol: IP-Adapter attention gated per concept
    by a mask derived from text attention.

    ``encoder_hidden_states`` is split into ``text_token_length`` (77) leading
    text tokens and the remaining image tokens, ``num_tokens`` per concept. The
    text cross-attention is computed as usual.

    For concept ``i``, a spatial mask is built from the attention the queries
    pay to the text tokens ``textual_concept_idxs[i]``:

    * averaged over those tokens and over heads;
    * divided by its spatial maximum;
    * reduced to the guided sample.

    The adapter's image-token attention output is multiplied by this mask and
    added with weight ``scale``. When ``encoder_hidden_states`` is None
    (self-attention) there are no concept tokens, so only the ordinary
    attention is computed.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; 16 for ip_adapter_plus):
            The context length of the image features of one concept.
        textual_concept_idxs (sequence of `(start, end)` pairs):
            For each concept, the slice of text-token positions whose attention
            defines that concept's mask. Must be set before the processor is called.
        name (`str`):
            Key under which this processor logs its masks. When ``global_masking``
            is enabled, it is also compared against ``concept_mask_layer[0]``.
        global_masking (`bool`, defaults to False):
            If True, the processor named ``concept_mask_layer[0]`` stores its masks in
            the module-level ``global_concept_mask``. Every other processor then uses
            those masks (nearest-neighbour resized) instead of its own.
        adaptive_scale_mask (`bool`, defaults to False):
            Deprecated; calling the processor with this set raises ``ValueError``.
        concept_mask_layer (`list[str]`, optional):
            Names of the processors that compute the global mask; only the first
            entry is read. Defaults to the SD mid-block cross-attention processor.
    """

    # Number of leading text tokens in `encoder_hidden_states` (both SD and SDXL use 77).
    text_token_length = 77
    # Text-token position whose attention map is logged to `image_attn_map_logs`.
    # NOTE(review): hard-coded in the original code; the token it is meant to track
    # is not visible here — confirm.
    vis_token_idx = 15

    def __init__(
        self,
        hidden_size,
        cross_attention_dim=None,
        scale=1.0,
        num_tokens=4,
        textual_concept_idxs=None,
        name=None,
        global_masking=False,
        adaptive_scale_mask=False,
        concept_mask_layer=None,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.textual_concept_idxs = textual_concept_idxs
        self.name = name
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.global_masking = global_masking
        self.adaptive_scale_mask = adaptive_scale_mask
        if concept_mask_layer is None:
            concept_mask_layer = [
                "mid_block.attentions.0.transformer_blocks.0.attn2.processor"
            ]  # For SD
            print("Warning: Using default concept mask layer for SD. For SDXL, use 'up_blocks.0.attentions.1.transformer_blocks.5.attn2.processor'")
        self.concept_mask_layer = concept_mask_layer

    def set_global_view(self, attn_procs):
        """Keep a reference to the full processor mapping (``__call__`` does not read it)."""
        self.attn_procs = attn_procs

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Text (or self-) attention plus the concept-masked image-prompt attention.

        Raises:
            ValueError: if ``textual_concept_idxs`` is unset, if
                ``adaptive_scale_mask`` is set, or if the image tokens do not split
                into one ``num_tokens`` chunk per textual concept.
        """
        if self.textual_concept_idxs is None:
            raise ValueError(
                "textual_concept_idxs should be provided for ConceptrolAttnProcessor"
            )
        if self.adaptive_scale_mask:
            raise ValueError("adaptive_scale_mask is deprecated already")

        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states)

        # Stays None for self-attention: there are no concept tokens then.
        ip_hidden_states_list = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            end_pos = self.text_token_length
            ip_hidden_states_cat = encoder_hidden_states[:, end_pos:, :]
            encoder_hidden_states = encoder_hidden_states[:, :end_pos, :]
            ip_hidden_states_list = self._split_concepts(ip_hidden_states_cat)
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)
        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        if ip_hidden_states_list is not None:
            hidden_states = self._add_concept_attention(
                attn, query, attention_probs, hidden_states, ip_hidden_states_list
            )

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states

    def _split_concepts(self, ip_hidden_states_cat):
        """Split the image tokens into one ``num_tokens``-long chunk per concept."""
        num_ip_tokens = ip_hidden_states_cat.shape[1]
        num_concepts, remainder = divmod(num_ip_tokens, self.num_tokens)
        if remainder or num_concepts != len(self.textual_concept_idxs):
            raise ValueError(
                "textual_concept_idxs should have the same length as the number of "
                f"concepts, but got {num_concepts} concepts ({num_ip_tokens} image "
                f"tokens, {self.num_tokens} per concept) and "
                f"{len(self.textual_concept_idxs)} textual concepts"
            )
        return torch.split(ip_hidden_states_cat, self.num_tokens, dim=1)

    @staticmethod
    def _select_guided(t):
        """Keep batch index 1 (as the original code does), or the only sample for batch size 1."""
        # NOTE(review): index 1 assumes a classifier-free-guidance batch laid out as
        # [uncond, cond]; confirm the intended choice for batches larger than 2.
        return t[1:2] if t.shape[0] > 1 else t

    def _text_concept_mask(self, attn, attention_probs, start, end):
        """Return a spatial mask from the attention paid to text tokens ``[start, end)``.

        The attention is averaged over the token slice and over heads, divided by its
        maximum over the query positions (plus 1e-6), and reduced to the guided sample.
        """
        mask = attention_probs[:, :, start:end]
        mask = torch.mean(mask, dim=-1, keepdim=True)
        mask = attn.batch_to_head_dim(mask)
        mask = torch.mean(mask, dim=-1, keepdim=True)
        mask = mask / (torch.amax(mask, dim=-2, keepdim=True) + 1e-6)
        return self._select_guided(mask)

    @staticmethod
    def _resize_mask(mask, num_pixels):
        """Nearest-neighbour resize a ``(B, N, 1)`` mask to ``num_pixels`` positions.

        Assumes both token grids are square (side = int(sqrt(tokens))) — TODO confirm
        for non-square resolutions.
        """
        original_dim = int(mask.shape[1] ** 0.5)
        target_dim = int(num_pixels ** 0.5)
        mask_2d = mask.view(mask.shape[0], 1, original_dim, original_dim)
        resized = F.interpolate(
            mask_2d, size=(target_dim, target_dim), mode="nearest"
        )
        return resized.view(mask.shape[0], -1, 1)

    def _add_concept_attention(
        self, attn, query, attention_probs, hidden_states, ip_hidden_states_list
    ):
        """Add each concept's masked image-prompt attention to ``hidden_states`` in place."""
        global global_concept_mask

        num_concepts = len(self.textual_concept_idxs)
        # (Re)initialize the shared slots when empty or sized for another concept count;
        # a stale shorter list would otherwise raise IndexError below.
        if len(global_concept_mask) != num_concepts:
            global_concept_mask = [None for _ in range(num_concepts)]

        is_mask_layer = False
        reuses_global_mask = False
        if self.global_masking:
            is_mask_layer = self.name == self.concept_mask_layer[0]
            reuses_global_mask = not is_mask_layer

        if self.name not in attn_mask_logs:
            attn_mask_logs[self.name] = []
            text_attn_map_logs[self.name] = []
            image_attn_map_logs[self.name] = []

        # Does not depend on the concept: compute once, log once per concept.
        idx = self.vis_token_idx
        token_map = self._text_concept_mask(attn, attention_probs, idx, idx + 1)
        token_map_np = token_map.detach().cpu().numpy()[0, :, 0]

        for i, (start, end) in enumerate(self.textual_concept_idxs):
            ip_hidden_states = ip_hidden_states_list[i]
            ip_key = attn.head_to_batch_dim(self.to_k_ip(ip_hidden_states))
            ip_value = attn.head_to_batch_dim(self.to_v_ip(ip_hidden_states))

            ip_attention_mask = self._text_concept_mask(
                attn, attention_probs, start, end
            )
            mask_np = ip_attention_mask.detach().cpu().numpy()[0, :, 0]
            attn_mask_logs[self.name].append(mask_np)
            # NOTE(review): logs the same values as attn_mask_logs; confirm intent.
            text_attn_map_logs[self.name].append(mask_np.copy())

            if is_mask_layer:
                global_concept_mask[i] = ip_attention_mask
            elif reuses_global_mask and global_concept_mask[i] is not None:
                ip_attention_mask = self._resize_mask(
                    global_concept_mask[i], hidden_states.shape[1]
                )

            ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
            image_attn_map_logs[self.name].append(token_map_np.copy())

            ip_hidden_states = attn.batch_to_head_dim(
                torch.bmm(ip_attention_probs, ip_value)
            )
            hidden_states += self.scale * (ip_hidden_states * ip_attention_mask)
        return hidden_states
class AttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor backed by ``F.scaled_dot_product_attention`` (PyTorch >= 2.0).

    ``hidden_size`` and ``cross_attention_dim`` are accepted only for signature
    compatibility with the IP-Adapter processors; they are ignored.
    """

    def __init__(
        self,
        hidden_size=None,
        cross_attention_dim=None,
    ):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend from ``hidden_states`` to ``encoder_hidden_states`` (or to itself)."""
        residual = hidden_states
        x = (
            hidden_states
            if attn.spatial_norm is None
            else attn.spatial_norm(hidden_states, temb)
        )

        is_spatial = x.ndim == 4
        if is_spatial:
            _, channel, height, width = x.shape
            x = x.view(x.shape[0], channel, height * width).transpose(1, 2)

        source = x if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # SDPA expects the mask as (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            x = attn.group_norm(x.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = x
        elif attn.norm_cross:
            context = attn.norm_encoder_hidden_states(encoder_hidden_states)
        else:
            context = encoder_hidden_states

        query = attn.to_q(x)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(t):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(linear_proj(out))

        if is_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            out = out + residual
        return out / attn.rescale_output_factor
class IPAttnProcessor2_0(torch.nn.Module):
    r"""
    Attention processor for IP-Adapter for PyTorch 2.0 (``F.scaled_dot_product_attention``).

    For cross-attention, the last ``num_tokens`` entries of
    ``encoder_hidden_states`` are image-prompt tokens. They are projected with
    ``to_k_ip`` / ``to_v_ip``, and their attention output is added, weighted by
    ``scale``, to the text cross-attention output.

    After such a call, ``self.attn_map`` holds the image-token attention
    probabilities, shaped ``(batch, heads, query_len, num_tokens)`` and
    computed without gradient. When ``encoder_hidden_states`` is None
    (self-attention), only the ordinary attention is computed.

    Args:
        hidden_size (`int`):
            The hidden size of the attention layer.
        cross_attention_dim (`int`):
            The number of channels in the `encoder_hidden_states`.
        scale (`float`, defaults to 1.0):
            The weight scale of the image prompt.
        num_tokens (`int`, defaults to 4; 16 for ip_adapter_plus):
            The context length of the image features.
    """

    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0, num_tokens=4):
        super().__init__()
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "IPAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.hidden_size = hidden_size
        self.cross_attention_dim = cross_attention_dim
        self.scale = scale
        self.num_tokens = num_tokens
        self.to_k_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )
        self.to_v_ip = nn.Linear(
            cross_attention_dim or hidden_size, hidden_size, bias=False
        )

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Text (or self-) attention plus ``scale`` times image-prompt attention."""
        residual = hidden_states
        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)
        input_ndim = hidden_states.ndim
        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(
                batch_size, channel, height * width
            ).transpose(1, 2)
        batch_size, sequence_length, _ = (
            hidden_states.shape
            if encoder_hidden_states is None
            else encoder_hidden_states.shape
        )
        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # scaled_dot_product_attention expects attention_mask shape to be
            # (batch, heads, source_length, target_length)
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )
        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(
                1, 2
            )
        query = attn.to_q(hidden_states)

        # Stays None for self-attention: there are no image-prompt tokens then.
        ip_hidden_states = None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        else:
            # Split the context into text tokens and trailing image-prompt tokens.
            end_pos = encoder_hidden_states.shape[1] - self.num_tokens
            encoder_hidden_states, ip_hidden_states = (
                encoder_hidden_states[:, :end_pos, :],
                encoder_hidden_states[:, end_pos:, :],
            )
            if attn.norm_cross:
                encoder_hidden_states = attn.norm_encoder_hidden_states(
                    encoder_hidden_states
                )

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)
        inner_dim = key.shape[-1]
        head_dim = inner_dim // attn.heads
        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
        # the output of sdp = (batch, num_heads, seq_len, head_dim)
        # TODO: add support for attn.scale when we move to Torch 2.1
        hidden_states = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        hidden_states = hidden_states.transpose(1, 2).reshape(
            batch_size, -1, attn.heads * head_dim
        )
        hidden_states = hidden_states.to(query.dtype)

        if ip_hidden_states is not None:
            # IP-Adapter branch: same queries, adapter-specific keys/values.
            ip_key = self.to_k_ip(ip_hidden_states)
            ip_value = self.to_v_ip(ip_hidden_states)
            ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
            ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(
                1, 2
            )
            ip_hidden_states = F.scaled_dot_product_attention(
                query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
            )
            with torch.no_grad():
                # Softmax must normalize the scaled scores (q @ k^T), not the keys;
                # the 1/sqrt(head_dim) factor matches SDPA's default scale.
                self.attn_map = (
                    query @ ip_key.transpose(-2, -1) * query.shape[-1] ** -0.5
                ).softmax(dim=-1)
            ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(
                batch_size, -1, attn.heads * head_dim
            )
            ip_hidden_states = ip_hidden_states.to(query.dtype)
            hidden_states = hidden_states + self.scale * ip_hidden_states

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)
        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(
                batch_size, channel, height, width
            )
        if attn.residual_connection:
            hidden_states = hidden_states + residual
        hidden_states = hidden_states / attn.rescale_output_factor
        return hidden_states
## for controlnet | |
class CNAttnProcessor:
    r"""
    Eager attention processor intended for ControlNet blocks.

    For cross-attention, the trailing ``num_tokens`` entries of
    ``encoder_hidden_states`` are dropped, so only the leading (text) tokens
    are attended to. Self-attention is unchanged.
    """

    def __init__(self, num_tokens=4):
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend from ``hidden_states`` to the text part of ``encoder_hidden_states``."""
        residual = hidden_states
        x = (
            hidden_states
            if attn.spatial_norm is None
            else attn.spatial_norm(hidden_states, temb)
        )

        is_spatial = x.ndim == 4
        if is_spatial:
            _, channel, height, width = x.shape
            x = x.view(x.shape[0], channel, height * width).transpose(1, 2)

        # Shape is taken from the untrimmed context, before dropping image tokens.
        source = x if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape
        attention_mask = attn.prepare_attention_mask(
            attention_mask, sequence_length, batch_size
        )

        if attn.group_norm is not None:
            x = attn.group_norm(x.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = x
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        query = attn.head_to_batch_dim(attn.to_q(x))
        key = attn.head_to_batch_dim(attn.to_k(context))
        value = attn.head_to_batch_dim(attn.to_v(context))

        probs = attn.get_attention_scores(query, key, attention_mask)
        out = attn.batch_to_head_dim(torch.bmm(probs, value))

        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(linear_proj(out))

        if is_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            out = out + residual
        return out / attn.rescale_output_factor
class CNAttnProcessor2_0:
    r"""
    SDPA-based attention processor intended for ControlNet blocks (PyTorch >= 2.0).

    For cross-attention, the trailing ``num_tokens`` entries of
    ``encoder_hidden_states`` are dropped, so only the leading (text) tokens
    are attended to. Self-attention is unchanged.
    """

    def __init__(self, num_tokens=4):
        if not hasattr(F, "scaled_dot_product_attention"):
            raise ImportError(
                "AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
            )
        self.num_tokens = num_tokens

    def __call__(
        self,
        attn,
        hidden_states,
        encoder_hidden_states=None,
        attention_mask=None,
        temb=None,
        *args,
        **kwargs,
    ):
        """Attend from ``hidden_states`` to the text part of ``encoder_hidden_states``."""
        residual = hidden_states
        x = (
            hidden_states
            if attn.spatial_norm is None
            else attn.spatial_norm(hidden_states, temb)
        )

        is_spatial = x.ndim == 4
        if is_spatial:
            _, channel, height, width = x.shape
            x = x.view(x.shape[0], channel, height * width).transpose(1, 2)

        # Shape is taken from the untrimmed context, before dropping image tokens.
        source = x if encoder_hidden_states is None else encoder_hidden_states
        batch_size, sequence_length, _ = source.shape

        if attention_mask is not None:
            attention_mask = attn.prepare_attention_mask(
                attention_mask, sequence_length, batch_size
            )
            # SDPA expects the mask as (batch, heads, query_len, key_len).
            attention_mask = attention_mask.view(
                batch_size, attn.heads, -1, attention_mask.shape[-1]
            )

        if attn.group_norm is not None:
            x = attn.group_norm(x.transpose(1, 2)).transpose(1, 2)

        if encoder_hidden_states is None:
            context = x
        else:
            text_len = encoder_hidden_states.shape[1] - self.num_tokens
            context = encoder_hidden_states[:, :text_len]  # only use text
            if attn.norm_cross:
                context = attn.norm_encoder_hidden_states(context)

        query = attn.to_q(x)
        key = attn.to_k(context)
        value = attn.to_v(context)

        head_dim = key.shape[-1] // attn.heads

        def split_heads(t):
            # (batch, seq, heads * head_dim) -> (batch, heads, seq, head_dim)
            return t.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

        query, key, value = split_heads(query), split_heads(key), split_heads(value)

        # TODO: add support for attn.scale when we move to Torch 2.1
        out = F.scaled_dot_product_attention(
            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
        )
        out = out.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
        out = out.to(query.dtype)

        linear_proj, dropout = attn.to_out[0], attn.to_out[1]
        out = dropout(linear_proj(out))

        if is_spatial:
            out = out.transpose(-1, -2).reshape(batch_size, channel, height, width)
        if attn.residual_connection:
            out = out + residual
        return out / attn.rescale_output_factor