from functools import partial

import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint


def exists(val):
    return val is not None


def default(val, d):
    return val if exists(val) else d


# plain attention (used when memory_efficient is turned off)

def attention(
    q, k, v,
    mask=None,
    causal=False,
    attn_bias=None,
    **kwargs  # absorbs memory-efficient-only arguments (bucket sizes, dropout, training), which are ignored here
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    sim = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias):
        sim = sim + attn_bias

    mask_value = -torch.finfo(sim.dtype).max

    if exists(mask):
        if mask.ndim == 2:
            mask = rearrange(mask, 'b j -> b 1 1 j')
        sim = sim.masked_fill(~mask, mask_value)

    if causal:
        i, j = sim.shape[-2:]
        mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
        sim = sim.masked_fill(mask, mask_value)

    # subtract the row max for numerical stability before the softmax
    sim = sim - sim.amax(dim=-1, keepdim=True).detach()
    attn = sim.softmax(dim=-1)

    out = einsum('b h i j, b h j d -> b h i d', attn, v)
    return out


# memory efficient attention
# summarize one (q chunk, k/v chunk) pair: returns the un-normalized exp-weight row sums,
# the corresponding weighted values, and the per-row maxes needed to renormalize across chunks

def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
    q_start_index, k_start_index, q_chunk_size, k_chunk_size, device = *qk_start_indices, q.shape[-2], k.shape[-2], q.device

    weight = einsum('b h i d, b h j d -> b h i j', q, k)

    if exists(attn_bias_chunk):
        weight = weight + attn_bias_chunk

    mask_value = -torch.finfo(weight.dtype).max

    if exists(mask):
        mask = rearrange(mask, 'b j -> b 1 1 j')
        weight = weight.masked_fill(~mask, mask_value)

    if causal and q_start_index < (k_start_index + k_chunk_size - 1):
        causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype=torch.bool, device=device).triu(q_start_index - k_start_index + 1)
        weight = weight.masked_fill(causal_mask, mask_value)

    weight_max = weight.amax(dim=-1, keepdim=True).detach()
    weight = weight - weight_max

    exp_weight = weight.exp()

    exp_weight = F.dropout(exp_weight, p=dropout)

    weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)

    return exp_weight.sum(dim=-1), weighted_value, rearrange(weight_max, '... 1 -> ...')


# checkpoint the chunk summarization so the intermediate attention matrices are
# recomputed in the backward pass instead of being stored
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)


def memory_efficient_attention(
    q, k, v,
    mask=None,
    causal=False,
    attn_bias=None,
    q_bucket_size=512,
    k_bucket_size=1024,
    eps=1e-8,
    dropout=0.,
    training=False
):
    scale = q.shape[-1] ** -0.5
    q = q * scale

    # only pay the cost of activation checkpointing when a backward pass is needed
    needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
    summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk

    # chunk all the inputs along the sequence dimension
    q_chunks = q.split(q_bucket_size, dim=-2)
    k_chunks = k.split(k_bucket_size, dim=-2)
    v_chunks = v.split(k_bucket_size, dim=-2)
    mask_chunks = mask.split(k_bucket_size, dim=-1) if exists(mask) else ((None,) * len(k_chunks))

    if exists(attn_bias):
        attn_bias_chunks = attn_bias.split(q_bucket_size, dim=-2)
        attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim=-1), attn_bias_chunks))

    # loop over the query chunks, accumulating per-chunk summaries and then
    # renormalizing them with a streaming (log-sum-exp style) softmax
    out = []
    for q_index, q_chunk in enumerate(q_chunks):
        exp_weights = []
        weighted_values = []
        weight_maxes = []

        for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
            q_start_index = q_index * q_bucket_size
            k_start_index = k_index * k_bucket_size

            if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
                # this key/value chunk lies entirely in the future and is fully masked out - skip it
                continue

            attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None

            exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
                q_chunk,
                k_chunk,
                v_chunk,
                mask_chunk,
                attn_bias_chunk,
                causal,
                (q_start_index, k_start_index),
                dropout if training else 0.
            )

            exp_weights.append(exp_weight_chunk)
            weighted_values.append(weighted_value_chunk)
            weight_maxes.append(weight_max_chunk)

        weight_maxes = torch.stack(weight_maxes, dim=-1)
        weighted_values = torch.stack(weighted_values, dim=-1)
        exp_weights = torch.stack(exp_weights, dim=-1)

        # each chunk was exponentiated against its own row max; rescale all chunks to a shared global max
        global_max = weight_maxes.amax(dim=-1, keepdim=True)
        renorm_factor = (weight_maxes - global_max).exp().detach()

        exp_weights = exp_weights * renorm_factor
        weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')

        all_values = weighted_values.sum(dim=-1)
        all_weights = exp_weights.sum(dim=-1)

        normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
        out.append(normalized_values)

    return torch.cat(out, dim=-2)
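
# A minimal usage sketch (not part of the original module; shapes are illustrative assumptions):
#
#   q = k = v = torch.randn(1, 8, 2048, 64)          # (batch, heads, seq_len, dim_head)
#   out = memory_efficient_attention(q, k, v, causal=True, q_bucket_size=512, k_bucket_size=1024)
#
# `out` has the same shape as `q` and should agree with `attention(q, k, v, causal=True)`
# up to numerical error when dropout is 0.


# main class
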
class Attention(nn.Module):
    def __init__(
        self,
        *,
        dim,
        heads=8,
        dim_head=64,
        dropout=0.,
        causal=False,
        memory_efficient=False,
        q_bucket_size=512,
        k_bucket_size=1024
    ):
        super().__init__()
        self.heads = heads
        self.causal = causal
        self.dropout = dropout
        inner_dim = heads * dim_head

        self.to_q = nn.Linear(dim, inner_dim, bias=False)
        self.to_k = nn.Linear(dim, inner_dim, bias=False)
        self.to_v = nn.Linear(dim, inner_dim, bias=False)
        self.to_out = nn.Linear(inner_dim, dim, bias=False)

        # memory efficient attention related parameters
        # can be overridden on forward
        self.memory_efficient = memory_efficient
        self.q_bucket_size = q_bucket_size
        self.k_bucket_size = k_bucket_size

    def forward(
        self,
        q, k, v,
        mask=None,
        attn_bias=None,
        memory_efficient=None,
        q_bucket_size=None,
        k_bucket_size=None,
    ):
        memory_efficient = default(memory_efficient, self.memory_efficient)
        q_bucket_size = default(q_bucket_size, self.q_bucket_size)
        k_bucket_size = default(k_bucket_size, self.k_bucket_size)

        h = self.heads

        q = self.to_q(q)
        k = self.to_k(k)
        v = self.to_v(v)

        # split heads: (b, n, h * d) -> (b, h, n, d)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))

        attn_fn = attention if not memory_efficient else memory_efficient_attention

        out = attn_fn(q, k, v, mask=mask, attn_bias=attn_bias, causal=self.causal, q_bucket_size=q_bucket_size,
                      k_bucket_size=k_bucket_size, dropout=self.dropout, training=self.training)

        # merge heads back and project out
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)
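
# Illustrative module-level usage (an added sketch; dimensions are assumptions):
#
#   attn = Attention(dim=512, heads=8, dim_head=64, memory_efficient=True)
#   x = torch.randn(1, 4096, 512)
#   out = attn(x, x, x)    # -> (1, 4096, 512)
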

class MemoryEffTransformer(nn.Module):
    def __init__(self,
                 d_model,
                 nhead,
                 dim_feedforward=2048,
                 dropout=0.1,
                 activation=F.relu,
                 layer_norm_eps=1e-5):
        super().__init__()
        dim_head = d_model // nhead
        self.self_attn = Attention(dim=d_model,
                                   heads=nhead,
                                   dim_head=dim_head,
                                   memory_efficient=True)
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = activation
    def forward(self, x, need_mean=False):
        if isinstance(x, tuple):
            q, k, v = x
        else:
            q, k, v = x, x, x

        tmp = self.self_attn(q, k, v)

        if need_mean:
            # the batch dimension is assumed to hold two stacked bev queues; average the
            # attention output over the queue dimension and keep only the second half of
            # the queries (q[bs:])
            num_query, embed_dims, bs, num_bev_queue = (q.shape[1],
                                                        q.shape[2],
                                                        q.shape[0] // 2,
                                                        2)
            tmp = tmp.view(num_query, embed_dims, bs, num_bev_queue)
            tmp = tmp.mean(-1)
            tmp = tmp.permute(2, 0, 1)
            q = q[bs:]
            assert q.shape[0] == bs and q.shape[1] == num_query and q.shape[2] == embed_dims

        # post-norm residual blocks: attention, then feedforward
        q = self.norm1(q + self.dropout1(tmp))
        tmp = self.linear2(self.dropout(self.activation(self.linear1(q))))
        q = self.norm3(q + self.dropout3(tmp))

        return q
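

if __name__ == '__main__':
    # Quick sanity check (an added sketch, not part of the original module): with dropout
    # disabled, the chunked memory-efficient path should match plain attention up to
    # numerical error. Shapes below are arbitrary assumptions for illustration.
    q = torch.randn(2, 8, 1024, 64)
    k = torch.randn(2, 8, 1024, 64)
    v = torch.randn(2, 8, 1024, 64)

    out_plain = attention(q, k, v, causal=True)
    out_chunked = memory_efficient_attention(q, k, v, causal=True, q_bucket_size=256, k_bucket_size=512)

    print('max abs difference:', (out_plain - out_chunked).abs().max().item())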