|
import torch |
|
from einops import rearrange |
|
|
|
|
|
def low_version_attention(query, key, value, attn_bias=None):
    """Compute scaled dot-product attention with plain tensor ops.

    The attention logits are ``(query / sqrt(d)) @ key^T`` where ``d`` is the
    size of the last dimension of ``query``; they are softmax-normalised over
    the last (key) axis and used to weight ``value``.

    Args:
        query: Tensor whose last two dimensions are (query_len, d).
        key: Tensor whose last two dimensions are (key_len, d).
        value: Tensor whose last two dimensions are (key_len, d_v).
        attn_bias: Optional mask/bias broadcastable to the logits.
            A floating-point tensor is added to the logits as-is. A boolean
            tensor is interpreted like the ``attn_mask`` argument of
            ``torch.nn.functional.scaled_dot_product_attention``: ``True``
            keeps a position, ``False`` removes it from the softmax.

    Returns:
        The attention output, ``softmax(logits) @ value``.
    """
    scale = 1 / query.shape[-1] ** 0.5
    # Scale the query before the matmul so the logits are produced already scaled.
    logits = torch.matmul(query * scale, key.transpose(-2, -1))

    if attn_bias is not None:
        if isinstance(attn_bias, torch.Tensor) and attn_bias.dtype == torch.bool:
            # Adding a bool tensor would only shift logits by 0/1 and mask
            # nothing; turn it into an additive 0 / -inf bias instead.
            keep = attn_bias
            attn_bias = torch.zeros_like(keep, dtype=logits.dtype).masked_fill(~keep, float("-inf"))
        logits = logits + attn_bias

    weights = logits.softmax(-1)
    return weights @ value
|
|
|
|
|
class Attention(torch.nn.Module):
    """Multi-head attention built from four linear projections.

    Queries are projected from ``hidden_states``; keys and values are
    projected from ``encoder_hidden_states`` (or from ``hidden_states`` when
    no encoder states are given). The attended result is projected back to
    ``q_dim`` by ``to_out``.
    """

    def __init__(self, q_dim, num_heads, head_dim, kv_dim=None, bias_q=False, bias_kv=False, bias_out=False):
        """Create the projection layers.

        Args:
            q_dim: Feature size of the query input and of the output.
            num_heads: Number of attention heads.
            head_dim: Feature size of each head.
            kv_dim: Feature size of the key/value input; defaults to ``q_dim``.
            bias_q: Whether ``to_q`` has a bias term.
            bias_kv: Whether ``to_k`` and ``to_v`` have a bias term.
            bias_out: Whether ``to_out`` has a bias term.
        """
        super().__init__()
        if kv_dim is None:
            kv_dim = q_dim
        inner_width = num_heads * head_dim

        self.num_heads = num_heads
        self.head_dim = head_dim

        # Layers are created in q, k, v, out order; this fixes both the
        # parameter registration order and the order of random initialisation.
        self.to_q = torch.nn.Linear(q_dim, inner_width, bias=bias_q)
        self.to_k = torch.nn.Linear(kv_dim, inner_width, bias=bias_kv)
        self.to_v = torch.nn.Linear(kv_dim, inner_width, bias=bias_kv)
        self.to_out = torch.nn.Linear(inner_width, q_dim, bias=bias_out)

    def torch_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Run attention through ``torch.nn.functional.scaled_dot_product_attention``.

        Args:
            hidden_states: Query input; presumably (batch, seq, q_dim) - the
                projections are viewed as (batch, -1, num_heads, head_dim).
            encoder_hidden_states: Key/value input; ``hidden_states`` is used
                when this is ``None``. Its first dimension supplies the batch
                size used for every reshape.
            attn_mask: Optional mask forwarded unchanged to
                ``scaled_dot_product_attention``.

        Returns:
            The output of ``to_out`` applied to the merged heads.
        """
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states
        batch = context.shape[0]
        per_head = (batch, -1, self.num_heads, self.head_dim)

        # Project, then move the head axis in front of the sequence axis.
        q = self.to_q(hidden_states).view(*per_head).transpose(1, 2)
        k = self.to_k(context).view(*per_head).transpose(1, 2)
        v = self.to_v(context).view(*per_head).transpose(1, 2)

        attended = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)

        # Undo the head split: back to (batch, seq, num_heads * head_dim).
        merged = attended.transpose(1, 2).reshape(batch, -1, self.num_heads * self.head_dim)
        merged = merged.to(q.dtype)

        return self.to_out(merged)

    def xformers_forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Run attention through xformers, or through plain tensor ops when masked.

        Heads are folded into the batch axis (``(b n) f d``) before attention,
        so a supplied ``attn_mask`` has to broadcast against logits of that
        folded layout.

        Args:
            hidden_states: Query input of shape (batch, seq, features).
            encoder_hidden_states: Key/value input; ``hidden_states`` is used
                when this is ``None``.
            attn_mask: Optional mask. When given, ``low_version_attention`` is
                used with the mask as ``attn_bias``; otherwise
                ``xformers.ops.memory_efficient_attention`` is called.

        Returns:
            The output of ``to_out`` applied to the merged heads.
        """
        context = hidden_states if encoder_hidden_states is None else encoder_hidden_states

        fold_heads = "b f (n d) -> (b n) f d"
        q = rearrange(self.to_q(hidden_states), fold_heads, n=self.num_heads)
        k = rearrange(self.to_k(context), fold_heads, n=self.num_heads)
        v = rearrange(self.to_v(context), fold_heads, n=self.num_heads)

        if attn_mask is None:
            # Imported lazily so xformers is only required on this path.
            import xformers.ops as xops
            attended = xops.memory_efficient_attention(q, k, v)
        else:
            attended = low_version_attention(q, k, v, attn_bias=attn_mask)

        merged = rearrange(attended, "(b n) f d -> b f (n d)", n=self.num_heads)
        merged = merged.to(q.dtype)

        return self.to_out(merged)

    def forward(self, hidden_states, encoder_hidden_states=None, attn_mask=None):
        """Apply attention using the ``torch_forward`` implementation."""
        return self.torch_forward(
            hidden_states,
            encoder_hidden_states=encoder_hidden_states,
            attn_mask=attn_mask,
        )