from functools import partial
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn, einsum
from torch.utils.checkpoint import checkpoint
# helper functions
def exists(val):
return val is not None
def default(val, d):
return val if exists(val) else d
# regular attention
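# standard (dense) scaled dot-product attention; it materialises the full (i x j)
# similarity matrix, so memory grows quadratically with sequence length.
# extra keyword arguments such as the bucket sizes are accepted and ignored so that
# both attention functions share one call signature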
def attention(
q, k, v,
mask=None,
causal=False,
attn_bias=None,
    dropout=0.,
    training=False,
    **kwargs
):
scale = q.shape[-1] ** -0.5
q = q * scale
sim = einsum('b h i d, b h j d -> b h i j', q, k)
if exists(attn_bias):
sim = sim + attn_bias
mask_value = -torch.finfo(sim.dtype).max
if exists(mask):
if mask.ndim == 2:
mask = rearrange(mask, 'b j -> b 1 1 j')
sim = sim.masked_fill(~mask, mask_value)
if causal:
i, j = sim.shape[-2:]
mask = torch.ones(i, j, device=q.device, dtype=torch.bool).triu(j - i + 1)
sim = sim.masked_fill(mask, mask_value)
sim = sim - sim.amax(dim=-1, keepdim=True).detach()
    attn = sim.softmax(dim=-1)
    # attention dropout, matching the memory-efficient path
    attn = F.dropout(attn, p=dropout, training=training)
out = einsum('b h i j, b h j d -> b h i d', attn, v)
return out
# memory efficient attention
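# chunked attention in the style of "Self-attention Does Not Need O(n^2) Memory"
# (Rabe & Staats, 2021): queries and keys are processed in buckets, each bucket
# contributes partial softmax statistics, and the partials are recombined with a
# numerically stable log-sum-exp style renormalisation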
def summarize_qkv_chunk(q, k, v, mask, attn_bias_chunk, causal, qk_start_indices, dropout):
    q_start_index, k_start_index = qk_start_indices
    q_chunk_size, k_chunk_size, device = q.shape[-2], k.shape[-2], q.device
weight = einsum('b h i d, b h j d -> b h i j', q, k)
if exists(attn_bias_chunk):
weight = weight + attn_bias_chunk
mask_value = -torch.finfo(weight.dtype).max
if exists(mask):
mask = rearrange(mask, 'b j -> b 1 1 j')
weight = weight.masked_fill(~mask, mask_value)
if causal and q_start_index < (k_start_index + k_chunk_size - 1):
causal_mask = torch.ones((q_chunk_size, k_chunk_size), dtype=torch.bool, device=device).triu(
q_start_index - k_start_index + 1)
weight = weight.masked_fill(causal_mask, mask_value)
weight_max = weight.amax(dim=-1, keepdim=True).detach()
weight = weight - weight_max
exp_weight = weight.exp()
exp_weight = F.dropout(exp_weight, p=dropout)
weighted_value = einsum('b h i j, b h j d -> b h i d', exp_weight, v)
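    # return the pieces needed to rebuild the softmax later: per-row sums of the
    # exponentiated scores, the partially weighted values, and the per-row max
    # that was subtracted for numerical stability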
return exp_weight.sum(dim=-1), weighted_value, rearrange(weight_max, '... 1 -> ...')
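# gradient-checkpointed variant: chunk activations are recomputed during the
# backward pass instead of being stored, trading compute for peak memory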
checkpointed_summarize_qkv_chunk = partial(checkpoint, summarize_qkv_chunk)
def memory_efficient_attention(
q, k, v,
mask=None,
causal=False,
attn_bias=None,
q_bucket_size=512,
k_bucket_size=1024,
eps=1e-8,
dropout=0.,
training=False
):
scale = q.shape[-1] ** -0.5
q = q * scale
    # only pay the checkpointing overhead when gradients will actually flow
needs_backwards = q.requires_grad or k.requires_grad or v.requires_grad
summarize_qkv_fn = checkpointed_summarize_qkv_chunk if needs_backwards else summarize_qkv_chunk
# chunk all the inputs
q_chunks = q.split(q_bucket_size, dim=-2)
k_chunks = k.split(k_bucket_size, dim=-2)
v_chunks = v.split(k_bucket_size, dim=-2)
mask_chunks = mask.split(k_bucket_size, dim=-1) if exists(mask) else ((None,) * len(k_chunks))
if exists(attn_bias):
i, j = attn_bias.shape[-2:]
attn_bias_chunks = attn_bias.split(q_bucket_size, dim=-2)
attn_bias_chunks = list(map(lambda t: t.split(k_bucket_size, dim=-1), attn_bias_chunks))
# loop through all chunks and accumulate
out = []
for q_index, q_chunk in enumerate(q_chunks):
exp_weights = []
weighted_values = []
weight_maxes = []
for k_index, (k_chunk, v_chunk, mask_chunk) in enumerate(zip(k_chunks, v_chunks, mask_chunks)):
q_start_index = q_index * q_bucket_size
k_start_index = k_index * k_bucket_size
if causal and k_start_index > (q_start_index + q_chunk.shape[-2] - 1):
# if chunk is to be all masked out causally, skip
continue
attn_bias_chunk = attn_bias_chunks[q_index][k_index] if exists(attn_bias) else None
exp_weight_chunk, weighted_value_chunk, weight_max_chunk = summarize_qkv_fn(
q_chunk,
k_chunk,
v_chunk,
mask_chunk,
attn_bias_chunk,
causal,
(q_start_index, k_start_index),
dropout if training else 0.
)
exp_weights.append(exp_weight_chunk)
weighted_values.append(weighted_value_chunk)
weight_maxes.append(weight_max_chunk)
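        # merge the per-chunk statistics: rescale every chunk by exp(chunk_max - global_max)
        # so all partial softmaxes share one normaliser, then sum values and weights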
weight_maxes = torch.stack(weight_maxes, dim=-1)
weighted_values = torch.stack(weighted_values, dim=-1)
exp_weights = torch.stack(exp_weights, dim=-1)
global_max = weight_maxes.amax(dim=-1, keepdim=True)
renorm_factor = (weight_maxes - global_max).exp().detach()
exp_weights = exp_weights * renorm_factor
weighted_values = weighted_values * rearrange(renorm_factor, '... c -> ... 1 c')
all_values = weighted_values.sum(dim=-1)
all_weights = exp_weights.sum(dim=-1)
normalized_values = all_values / (rearrange(all_weights, '... -> ... 1') + eps)
out.append(normalized_values)
return torch.cat(out, dim=-2)
# main class
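# multi-head attention module that can route through either the dense or the
# memory-efficient kernel; the choice and bucket sizes set at construction time
# can be overridden per forward call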
class Attention(nn.Module):
def __init__(
self,
*,
dim,
heads=8,
dim_head=64,
dropout=0.,
causal=False,
memory_efficient=False,
q_bucket_size=512,
k_bucket_size=1024
):
super().__init__()
self.heads = heads
self.causal = causal
self.dropout = dropout
inner_dim = heads * dim_head
self.to_q = nn.Linear(dim, inner_dim, bias=False)
self.to_k = nn.Linear(dim, inner_dim, bias=False)
self.to_v = nn.Linear(dim, inner_dim, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False)
# memory efficient attention related parameters
        # can be overridden on forward
self.memory_efficient = memory_efficient
self.q_bucket_size = q_bucket_size
self.k_bucket_size = k_bucket_size
def forward(
self,
q, k, v,
mask=None,
attn_bias=None,
memory_efficient=None,
q_bucket_size=None,
k_bucket_size=None,
):
memory_efficient = default(memory_efficient, self.memory_efficient)
q_bucket_size = default(q_bucket_size, self.q_bucket_size)
k_bucket_size = default(k_bucket_size, self.k_bucket_size)
h = self.heads
q = self.to_q(q)
k = self.to_k(k)
v = self.to_v(v)
q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), (q, k, v))
attn_fn = attention if not memory_efficient else memory_efficient_attention
out = attn_fn(q, k, v, mask=mask, attn_bias=attn_bias, causal=self.causal, q_bucket_size=q_bucket_size,
k_bucket_size=k_bucket_size, dropout=self.dropout, training=self.training)
out = rearrange(out, 'b h n d -> b n (h d)')
return self.to_out(out)
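# post-norm transformer encoder block built on the memory-efficient attention above;
# see the need_mean branch in forward() for the temporal BEV-queue averaging path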
class MemoryEffTransformer(nn.Module):
def __init__(self,
d_model,
nhead,
dim_feedforward=2048,
dropout=0.1,
activation=F.relu,
layer_norm_eps=1e-5):
super().__init__()
dim_head = d_model // nhead
self.self_attn = Attention(dim=d_model,
heads=nhead,
dim_head=dim_head,
memory_efficient=True)
self.linear1 = nn.Linear(d_model, dim_feedforward)
self.dropout = nn.Dropout(dropout)
self.linear2 = nn.Linear(dim_feedforward, d_model)
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps)
self.dropout1 = nn.Dropout(dropout)
self.dropout3 = nn.Dropout(dropout)
self.activation = activation
def forward(self, x, need_mean=False):
if isinstance(x, tuple):
q, k, v = x
else:
q, k, v = x, x, x
tmp = self.self_attn(q, k, v)
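        # temporal-queue path: reshape the attention output to
        # (num_query, embed_dims, bs, num_bev_queue), average over the queue of two
        # frames, and keep only the second half of the input queries for the residual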
if need_mean:
            num_query, embed_dims = q.shape[1], q.shape[2]
            bs, num_bev_queue = q.shape[0] // 2, 2
tmp = tmp.view(num_query, embed_dims, bs, num_bev_queue)
tmp = tmp.mean(-1)
tmp = tmp.permute(2, 0, 1)
q = q[bs:]
            assert q.shape[0] == bs and q.shape[1] == num_query and q.shape[2] == embed_dims
q = self.norm1(q + self.dropout1(tmp))
tmp = self.linear2(self.dropout(self.activation(self.linear1(q))))
q = self.norm3(q + self.dropout3(tmp))
return q
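# minimal smoke test; a sketch assuming (batch, seq_len, d_model) inputs, with dropout
# disabled so the dense and chunked paths can be compared directly
if __name__ == '__main__':
    torch.manual_seed(0)

    attn = Attention(dim=256, heads=8, dim_head=32, causal=True, dropout=0.)
    attn.eval()

    x = torch.randn(2, 1024, 256)
    with torch.no_grad():
        out_dense = attn(x, x, x, memory_efficient=False)
        out_chunked = attn(x, x, x, memory_efficient=True, q_bucket_size=256, k_bucket_size=512)

    # the chunked path should agree with the dense path up to floating point error
    print('max abs diff:', (out_dense - out_chunked).abs().max().item())

    block = MemoryEffTransformer(d_model=256, nhead=8, dim_feedforward=1024)
    block.eval()
    with torch.no_grad():
        y = block(x)
    print('transformer block output:', y.shape)  # torch.Size([2, 1024, 256])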